From dd3b5455c61bddce96a94c2ce8f5d76ed4948ea1 Mon Sep 17 00:00:00 2001 From: Rahul Tanwani Date: Sun, 28 Feb 2016 23:16:34 -0800 Subject: [PATCH 01/28] [SPARK-13309][SQL] Fix type inference issue with CSV data Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309 Author: Rahul Tanwani Closes #11194 from tanwanirahul/master. --- .../datasources/csv/CSVInferSchema.scala | 18 +++++++++--------- sql/core/src/test/resources/simple_sparse.csv | 5 +++++ .../datasources/csv/CSVInferSchemaSuite.scala | 5 +++++ .../execution/datasources/csv/CSVSuite.scala | 14 +++++++++++++- 4 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/resources/simple_sparse.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index ace8cd7ad864..7f1ed28046b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ - private[csv] object CSVInferSchema { /** @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - StructField(thisHeader, rootType, nullable = true) + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } StructType(structFields) @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { } def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case ((a, b)) => - val tpe = findTightestCommonType(a, b).getOrElse(StringType) - tpe match { - case _: NullType => StringType - case other => other - } + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) } } @@ -140,6 +139,8 @@ private[csv] object CSVInferSchema { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) // Promote numeric types to the highest of the two and all numeric types to unlimited decimal case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => @@ -150,7 +151,6 @@ private[csv] object CSVInferSchema { } } - private[csv] object CSVTypeCast { /** diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 000000000000..02d29cabf95f --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index a1796f132600..412f1b89beee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite { 
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) } + + test("Merging NullTypes should yield NullType.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 7671bc106610..5d57d77ab054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString } @@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema.fieldNames.size === 1) } - test("DDL test with empty file") { sqlContext.sql(s""" |CREATE TEMPORARY TABLE carsTable @@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(carsCopy, withHeader = true) } } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } } From d81a71357e24160244b6eeff028b0d9a4863becf Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 29 Feb 2016 00:55:51 -0800 Subject: [PATCH 02/28] [SPARK-13545][MLLIB][PYSPARK] Make MLlib LogisticRegressionWithLBFGS's default parameters consistent in Scala and Python ## What changes were proposed in this pull request? * The default value of ```regParam``` of PySpark MLlib ```LogisticRegressionWithLBFGS``` should be consistent with Scala, which is ```0.0```. (This is also consistent with ML ```LogisticRegression```.) * BTW, if we use a known updater (L1 or L2) for binary classification, ```LogisticRegressionWithLBFGS``` will call the ML implementation. We should update the API doc to clarify that ```numCorrections``` will have no effect if we fall into that route. * Make a pass over all parameters of ```LogisticRegressionWithLBFGS```; the others are set properly. cc mengxr dbtsai ## How was this patch tested? No new tests; it should pass all current tests. Author: Yanbo Liang Closes #11424 from yanboliang/spark-13545.
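As a quick illustration of the second bullet, a minimal Scala sketch (not part of the patch; `training` is an assumed `RDD[LabeledPoint]`): with a known updater and two classes, the run is delegated to the ml implementation, where the number of corrections is not configurable.

```scala
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.optimization.SquaredL2Updater

val lr = new LogisticRegressionWithLBFGS().setNumClasses(2)
lr.optimizer
  .setRegParam(0.0)                  // now matches the Scala/ML default
  .setUpdater(new SquaredL2Updater)  // known updater + binary classification => ml route
  .setNumCorrections(10)             // ignored on the ml route, as documented above
val model = lr.run(training)         // `training` is an assumed RDD[LabeledPoint]
```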
--- .../spark/mllib/classification/LogisticRegression.scala | 4 ++++ python/pyspark/mllib/classification.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c3882606d7db..f807b5683c39 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -408,6 +408,10 @@ class LogisticRegressionWithLBFGS * defaults to the mllib implementation. If more than two classes * or feature scaling is disabled, always uses mllib implementation. * Uses user provided weights. + * + * In the ml LogisticRegression implementation, the number of corrections + * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()` + * will have no effect if we fall into that route. */ override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { run(input, initialWeights, userSuppliedWeights = true) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index b4d54ef61b0e..53a0df27cace 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -326,7 +326,7 @@ class LogisticRegressionWithLBFGS(object): """ @classmethod @since('1.2.0') - def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", + def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -341,7 +341,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: None) :param regParam: The regularizer parameter. - (default: 0.01) + (default: 0.0) :param regType: The type of regularizer used for training our model. Allowed values: @@ -356,7 +356,9 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: False) :param corrections: The number of corrections used in the LBFGS update. - (default: 10) + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) :param tolerance: The convergence tolerance of iterations for L-BFGS. (default: 1e-6) From 99fe8993f51d3c72cd95eb0825b090dd4d4cd2cd Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 29 Feb 2016 12:08:37 +0000 Subject: [PATCH 03/28] =?UTF-8?q?[SPARK-12994][CORE]=20It=20is=20not=20nec?= =?UTF-8?q?essary=20to=20create=20ExecutorAllocationM=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …anager in local mode Author: Jeff Zhang Closes #10914 from zjffdu/SPARK-12994. 
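For illustration, a standalone sketch of the check this commit adds (`Utils` is `private[spark]`, so its logic is restated here; the conf values are assumptions):

```scala
import org.apache.spark.SparkConf

// Restates Utils.isLocalMaster from the diff below, for illustration only.
def isLocalMaster(conf: SparkConf): Boolean = {
  val master = conf.get("spark.master", "")
  master == "local" || master.startsWith("local[")
}

val conf = new SparkConf()
  .set("spark.master", "local[4]")
  .set("spark.dynamicAllocation.enabled", "true")

// Mirrors the updated Utils.isDynamicAllocationEnabled: in local mode the
// flag is ignored unless spark.dynamicAllocation.testing is also set.
val effective =
  conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
    conf.getInt("spark.executor.instances", 0) == 0 &&
    (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false))
// effective == false here, so no ExecutorAllocationManager is created.
```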
--- .../scala/org/apache/spark/SparkContext.scala | 6 +----- .../scala/org/apache/spark/util/Utils.scala | 19 +++++++++++++++++-- .../org/apache/spark/util/UtilsSuite.scala | 3 +++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1fa266e183e..0e8b735b923b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -244,7 +244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec - def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + def isLocal: Boolean = Utils.isLocalMaster(_conf) /** * @return true if context is stopped or in the midst of stopping. @@ -526,10 +526,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) - if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") - } - _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e0c9bf02a1a2..6103a10ccc50 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2195,6 +2195,16 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + + /** + * + * @return whether it is local mode + */ + def isLocalMaster(conf: SparkConf): Boolean = { + val master = conf.get("spark.master", "") + master == "local" || master.startsWith("local[") + } + /** * Return whether dynamic allocation is enabled in the given conf * Dynamic allocation and explicitly setting the number of executors are inherently @@ -2202,8 +2212,13 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). 
*/ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.getBoolean("spark.dynamicAllocation.enabled", false) && - conf.getInt("spark.executor.instances", 0) == 0 + val numExecutor = conf.getInt("spark.executor.instances", 0) + val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) + if (numExecutor != 0 && dynamicAllocationEnabled) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + numExecutor == 0 && dynamicAllocationEnabled && + (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7c6778b06546..412c0ac9d9be 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -722,6 +722,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("isDynamicAllocationEnabled") { val conf = new SparkConf() + conf.set("spark.master", "yarn-client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) @@ -731,6 +732,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "1")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) } test("encodeFileNameToURIRawPath") { From 236e3c8fbc887e4da4f143cbf533f016f21c10d4 Mon Sep 17 00:00:00 2001 From: vijaykiran Date: Mon, 29 Feb 2016 15:52:41 +0200 Subject: [PATCH 04/28] [SPARK-12633][PYSPARK] [DOC] PySpark regression parameter desc to consistent format Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the regression module. Also, updated 2 params in classification to read as `Supported values:` to be consistent. closes #10600 Author: vijaykiran Author: Bryan Cutler Closes #11404 from BryanCutler/param-desc-consistent-regression-SPARK-12633. --- python/pyspark/mllib/classification.py | 4 +- python/pyspark/mllib/regression.py | 326 +++++++++++++------------ 2 files changed, 166 insertions(+), 164 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 53a0df27cace..57106f8690a7 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -294,7 +294,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: 0.01) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -344,7 +344,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType= (default: 0.0) :param regType: The type of regularizer used for training our model. 
- Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 4dd7083d79c8..3b77a6200054 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,10 +37,11 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + :param label: + Label for this data point. + :param features: + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). Note: 'label' and 'features' are accessible as class attributes. @@ -66,8 +67,10 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. .. versionadded:: 0.9.0 """ @@ -217,19 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ - Train a linear regression model with no regularization using Stochastic Gradient Descent. - This solves the least squares regression formulation - - f(weights) = 1/n ||A weights-y||^2 - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @@ -237,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient - Descent (SGD). - This solves the least squares regression formulation - - f(weights) = 1/(2n) ||A weights - y||^2, - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.0). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization (lasso), - - "l2" for using L2 regularization (ridge), - - None for no regularization - - (default: None) - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. 
- (default: 0.001) + Descent (SGD). This solves the least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + + which is the mean squared error. Here the data matrix has n rows, + and the input RDD holds the set of rows of A, each with its + corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization (default) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -368,56 +365,53 @@ def load(cls, sc, path): class LassoWithSGD(object): """ - Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the L1-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L1-regularization using - Stochastic Gradient Descent. - This solves the l1-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. 
- (default: 0.001) + Train a regression model with L1-regularization using Stochastic + Gradient Descent. This solves the l1-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -508,56 +502,53 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ - Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the L2-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L2-regularization using - Stochastic Gradient Descent. - This solves the l2-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L2-regularization using Stochastic + Gradient Descent. 
This solves the l2-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), @@ -572,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: Array of boundaries for which predictions are - known. Boundaries must be sorted in increasing order. - :param predictions: Array of predictions associated to the - boundaries at the same index. Results of isotonic - regression and therefore monotone. - :param isotonic: indicates whether this is isotonic or antitonic. + :param boundaries: + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + :param predictions: + Array of predictions associated to the boundaries at the same + index. Results of isotonic regression and therefore monotone. + :param isotonic: + Indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -628,7 +621,8 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: Feature or RDD of Features to be labeled. + :param x: + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -657,8 +651,8 @@ def load(cls, sc, path): class IsotonicRegression(object): """ Isotonic regression. - Currently implemented using parallelized pool adjacent violators algorithm. - Only univariate (single feature) algorithm supported. + Currently implemented using parallelized pool adjacent violators + algorithm. Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: @@ -684,8 +678,11 @@ def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. + :param data: + RDD of (label, feature, weight) tuples. + :param isotonic: + Whether this is isotonic (which is default) or antitonic. 
+ (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -721,9 +718,11 @@ def _validate(self, dstream): @since("1.5.0") def predictOn(self, dstream): """ - Make predictions on a dstream. + Use the model to make predictions on batches of data from a + DStream. - :return: Transformed dstream object. + :return: + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) @@ -731,9 +730,11 @@ def predictOn(self, dstream): @since("1.5.0") def predictOnValues(self, dstream): """ - Make predictions on a keyed dstream. + Use the model to make predictions on the values of a DStream and + carry over its keys. - :return: Transformed dstream object. + :return: + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -742,14 +743,15 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Train or predict a linear regression model on streaming data. Training uses - Stochastic Gradient Descent to update the model based on each new batch of - incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + Train or predict a linear regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model + based on each new batch of incoming data from a DStream + (see `LinearRegressionWithSGD` for model equation). Each batch of data is assumed to be an RDD of LabeledPoints. The number of data points per batch can vary, but the number - of features must be constant. An initial weight - vector must be provided. + of features must be constant. An initial weight vector must + be provided. :param stepSize: Step size for each iteration of gradient descent. From 2f91f5ac0d2a5932169b245e3ef3e19849131277 Mon Sep 17 00:00:00 2001 From: zhuol Date: Mon, 29 Feb 2016 08:37:33 -0600 Subject: [PATCH 05/28] [SPARK-13481] Desc order of appID by default for history server page. ## What changes were proposed in this pull request? Now, by default, applications are shown in ascending order of appId. We might prefer to display them in descending order by default, which will show the latest application at the top. ## How was this patch tested? Manually tested. See screenshot below: ![desc-sort](https://cloud.githubusercontent.com/assets/11683054/13307473/102f4cf8-db31-11e5-8dd5-391edbf32f0d.png) Author: zhuol Closes #11357 from zhuoliu/13481.
--- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6195916195e3..167c8020850d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -149,7 +149,8 @@ $(document).ready(function() { {name: 'seventh'}, {name: 'eighth'}, ], - "autoWidth": false + "autoWidth": false, + "order": [[ 0, "desc" ]] }; var rowGroupConf = { From ac5c635281aa796547e31e20c10d2483469294ee Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 14:51:27 +0000 Subject: [PATCH 06/28] [SPARK-13506][MLLIB] Fix the wrong parameter in R code comment in AssociationRulesSuite JIRA: https://issues.apache.org/jira/browse/SPARK-13506 ## What changes were proposed in this pull request? Just change the R snippet comment in AssociationRulesSuite so that its parameters match the test. ## How was this patch tested? Unit tests passed. Author: Zheng RuiFeng Closes #11387 from zhengruifeng/ars. --- .../org/apache/spark/mllib/fpm/AssociationRulesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index 77a2773c36f5..dcb1f398b04b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() /* Verify results using the `R` code: + library(arules) transactions = as(sapply( list("r z h k p", "z y x w v u t s", FUN=function(x) strsplit(x," ",fixed=TRUE)), "transactions") ars = apriori(transactions, - parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) arsDF = as(ars, "data.frame") arsDF$support = arsDF$support * length(transactions) names(arsDF)[names(arsDF) == "support"] = "freq" From 916fc34f98dd731f607d9b3ed657bad6cc30df2c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Mar 2016 01:07:45 +0800 Subject: [PATCH 07/28] [SPARK-13540][SQL] Supports using nested classes within Scala objects as Dataset element type ## What changes were proposed in this pull request? Nested classes defined within Scala objects are translated into Java static nested classes. Unlike inner classes, they don't need outer scopes. But the analyzer still thinks that an outer scope is required. This PR fixes this issue simply by checking whether a nested class is static before looking up its outer scope. ## How was this patch tested? A test case is added to `DatasetSuite`. It checks the contents of a Dataset whose element type is a nested class declared in a Scala object. Author: Cheng Lian Closes #11421 from liancheng/spark-13540-object-as-outer-scope.
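For illustration, a minimal sketch of what this fix enables, mirroring the test added below (assumes a `SQLContext` named `sqlContext` with its implicits in scope):

```scala
import sqlContext.implicits._

object OuterObject {
  case class InnerClass(a: String)
}

// InnerClass compiles to a *static* nested class, so no outer pointer is
// needed; before this fix the analyzer still tried to look one up and failed.
val ds = Seq(OuterObject.InnerClass("foo")).toDS()
ds.collect()  // Array(InnerClass("foo"))
```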
--- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++++- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 23e4709bbd88..876aa0eae0e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -559,7 +561,13 @@ class Analyzer( } resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + case n: NewInstance + // If this is an inner class of another class, register the outer object in `OuterScopes`. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + if n.outerPointer.isEmpty && + n.cls.isMemberClass && + !Modifier.isStatic(n.cls.getModifiers) => val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 14fc37b64aa3..33df6375e3aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -621,12 +621,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds.filter(_ => true), Some(1), Some(2), Some(3)) } + + test("SPARK-13540 Dataset of nested class defined in Scala object") { + checkAnswer( + Seq(OuterObject.InnerClass("foo")).toDS(), + OuterObject.InnerClass("foo")) + } } class OuterClass extends Serializable { case class InnerClass(a: String) } +object OuterObject { + case class InnerClass(a: String) +} + case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) From 02aa499dfb71bc9571bebb79e6383842e4f48143 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Feb 2016 09:44:29 -0800 Subject: [PATCH 08/28] [SPARK-13509][SPARK-13507][SQL] Support for writing CSV with a single function call https://issues.apache.org/jira/browse/SPARK-13507 https://issues.apache.org/jira/browse/SPARK-13509 ## What changes were proposed in this pull request? This PR adds support for writing CSV data directly to the given path with a single function call. Several unit tests were added for each functionality. ## How was this patch tested? This was tested with unit tests and with `dev/run_tests` for coding style. Author: hyukjinkwon Author: Hyukjin Kwon Closes #11389 from HyukjinKwon/SPARK-13507-13509.
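For illustration, a minimal Scala round trip through the new write path (paths are placeholders; the read side mirrors the updated `CSVSuite` test below):

```scala
// Reading still goes through the generic loader, as in the updated test suite.
val cars = sqlContext.read
  .format("csv")
  .option("header", "true")
  .load("/path/to/cars.csv")

// Writing can now be a single call; `compression` is the codec option
// documented in the DataFrameWriter changes below.
cars.coalesce(1).write
  .option("header", "true")
  .option("compression", "gzip")
  .csv("/path/to/output")
```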
--- python/pyspark/sql/readwriter.py | 50 +++++++++++++++++++ python/test_support/sql/ages.csv | 4 ++ .../apache/spark/sql/DataFrameWriter.scala | 23 +++++++++ .../datasources/json/JSONOptions.scala | 5 +- .../datasources/text/DefaultSource.scala | 5 +- .../execution/datasources/csv/CSVSuite.scala | 3 +- 6 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 python/test_support/sql/ages.csv diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b1453c637f79..7f5368d8bdbb 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -233,6 +233,23 @@ def text(self, paths): paths = [paths] return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(2.0) + def csv(self, paths): + """Loads a CSV file and returns the result as a [[DataFrame]]. + + This function goes through the input once to determine the input schema. To avoid going + through the entire data once, specify the schema explicitly using [[schema]]. + + :param paths: string, or list of strings, for input path(s). + + >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df.dtypes + [('C0', 'string'), ('C1', 'string')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(1.5) def orc(self, path): """Loads an ORC file, returning the result as a :class:`DataFrame`. @@ -448,6 +465,11 @@ def json(self, path, mode=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + You can set the following JSON-specific option(s) for writing JSON files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode)._jwrite.json(path) @@ -476,11 +498,39 @@ def parquet(self, path, mode=None, partitionBy=None): def text(self, path): """Saves the content of the DataFrame in a text file at the specified path. + :param path: the path in any Hadoop supported file system + The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. + + You can set the following option(s) for writing text files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). """ self._jwrite.text(path) + @since(2.0) + def csv(self, path, mode=None): + """Saves the content of the [[DataFrame]] in CSV format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + You can set the following CSV-specific option(s) for writing CSV files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shorten names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). 
+ + >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.csv(path) + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv new file mode 100644 index 000000000000..18991feda788 --- /dev/null +++ b/python/test_support/sql/ages.csv @@ -0,0 +1,4 @@ +Joe,20 +Tom,30 +Hyukjin,25 + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d6bdd3d82556..093504c765ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -453,6 +453,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("json").save(path) * }}} * + * You can set the following JSON-specific option(s) for writing JSON files: + *
<li>`compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li> + * + * @since 1.4.0 */ def json(path: String): Unit = format("json").save(path) @@ -492,10 +496,29 @@ final class DataFrameWriter private[sql](df: DataFrame) { * df.write().text("/path/to/output") * }}} * + * You can set the following option(s) for writing text files: + *
<li>`compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li> + * + * @since 1.6.0 */ def text(path: String): Unit = format("text").save(path) + /** + * Saves the content of the [[DataFrame]] in CSV format at the specified path. + * This is equivalent to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can set the following CSV-specific option(s) for writing CSV files: + *
<li>`compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li> + * + * @since 2.0.0 */ def csv(path: String): Unit = format("csv").save(path) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 31a95ed46121..e59dbd6b3d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -48,10 +48,7 @@ private[sql] class JSONOptions( parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 60155b32349a..8f3f6335e428 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -115,10 +115,7 @@ private[sql] class TextRelation( /** Write path. */ override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = job.getConfiguration - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5d57d77ab054..3ecbb14f2ea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -268,9 +268,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)) cars.coalesce(1).write - .format("csv") .option("header", "true") - .save(csvDir) + .csv(csvDir) val carsCopy = sqlContext.read .format("csv") From bc65f60ef7c920db756bfe643f7edbdf3593a989 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 29 Feb 2016 10:10:04 -0800 Subject: [PATCH 09/28] [SPARK-13544][SQL] Rewrite/Propagate Constraints for Aliases in Aggregate #### What changes were proposed in this pull request? After analysis by Analyzer, two operators could have aliases. They are `Project` and `Aggregate`. So far, we only rewrite and propagate constraints if `Alias` is defined in `Project`. This PR is to resolve this issue in `Aggregate`. #### How was this patch tested?
Added a test case for `Aggregate` in `ConstraintPropagationSuite`. marmbrus sameeragarwal Author: gatorsmile Closes #11422 from gatorsmile/validConstraintsInUnaryNodes. --- .../catalyst/plans/logical/LogicalPlan.scala | 16 ++++++++++ .../plans/logical/basicOperators.scala | 30 +++++-------------- .../plans/ConstraintPropagationSuite.scala | 15 ++++++++++ 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8095083f336e..31e775d60f95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + override protected def validConstraints: Set[Expression] = child.constraints override def statistics: Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d2a65b716b2..e81a0f948746 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - private def getAliasedConstraints: Set[Expression] = { - projectList.flatMap { - case a @ Alias(e, _) => - child.constraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet - } - - override def validConstraints: Set[Expression] = { - child.constraints.union(getAliasedConstraints) - } + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override protected def validConstraints: Set[Expression] = { + override protected 
def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - } // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. @@ -442,6 +423,9 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + override def statistics: Statistics = { if (groupingExpressions.isEmpty) { Statistics(sizeInBytes = 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 373b1ffa83d2..b68432b1a128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) From 17a253cbf4712dbeab06c454b5142917a1bba78b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 29 Feb 2016 11:02:45 -0800 Subject: [PATCH 10/28] [SPARK-13522][CORE] Executor should kill itself when it's unable to heartbeat to driver more than N times ## What changes were proposed in this pull request? Sometimes, network disconnection event won't be triggered for other potential race conditions that we may not have thought of, then the executor will keep sending heartbeats to driver and won't exit. This PR adds a new configuration `spark.executor.heartbeat.maxFailures` to kill Executor when it's unable to heartbeat to the driver more than `spark.executor.heartbeat.maxFailures` times. ## How was this patch tested? unit tests Author: Shixiong Zhu Closes #11401 from zsxwing/SPARK-13522. --- .../org/apache/spark/executor/Executor.scala | 22 ++++++++++++++++++- .../spark/executor/ExecutorExitCode.scala | 8 +++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a602fcac68a6..86c121f78746 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -114,6 +114,19 @@ private[spark] class Executor( private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. 
It means we will keep retrying to send + heartbeats for about 10 minutes, because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Counts consecutive heartbeat failures. It should only be accessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. + */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -461,8 +474,15 @@ private[spark] class Executor( logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + logError(s"Unable to send heartbeats to driver more than $HEARTBEAT_MAX_FAILURES times") + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd54..99858f785600 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." + case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { From 644dbb641afd337ca39733da5153239cf39cdd81 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 29 Feb 2016 11:52:11 -0800 Subject: [PATCH 11/28] [SPARK-13522][CORE] Fix the exit log place for heartbeat ## What changes were proposed in this pull request? Just fixed the placement of the error log introduced by #11401 ## How was this patch tested? unit tests. Author: Shixiong Zhu Closes #11432 from zsxwing/SPARK-13522-follow-up.
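For illustration, a sketch of how these two commits behave together (the interval is Spark's existing default; the threshold is the new default added above):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.heartbeatInterval", "10s")     // existing default interval
  .set("spark.executor.heartbeat.maxFailures", "60")  // new threshold introduced here

// After 60 consecutive failed heartbeats (about 10 minutes at a 10s interval),
// the executor calls System.exit(ExecutorExitCode.HEARTBEAT_FAILURE), i.e.
// exits with code 56, instead of heartbeating to an unreachable driver forever.
```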
--- core/src/main/scala/org/apache/spark/executor/Executor.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 86c121f78746..a959f200d4cc 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -478,9 +478,10 @@ private[spark] class Executor( } catch { case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) - logError(s"Unable to send heartbeats to driver more than $HEARTBEAT_MAX_FAILURES times") heartbeatFailures += 1 if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) } } From 4bd697da03079c26fd4409dc128dbff28c737701 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 29 Feb 2016 12:59:46 -0800 Subject: [PATCH 12/28] [SPARK-13123][SQL] Implement whole stage codegen for sort ## What changes were proposed in this pull request? This PR adds support for whole stage codegen for sort. Builds heavily on nongli's PR: https://github.com/apache/spark/pull/11008 (which actually implements the feature), and adds the following changes on top: - [x] Generated code updates peak execution memory metrics - [x] Unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite` ## How was this patch tested? New unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`. Further, all existing sort tests should pass. Author: Sameer Agarwal Author: Nong Li Closes #11359 from sameeragarwal/sort-codegen. --- .../execution/UnsafeExternalRowSorter.java | 9 +- .../org/apache/spark/sql/execution/Sort.scala | 124 ++++++++++++++---- .../sql/execution/WholeStageCodegen.scala | 8 +- .../execution/WholeStageCodegenSuite.scala | 9 ++ .../execution/metric/SQLMetricsSuite.scala | 7 + 5 files changed, 122 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 27ae62f1212f..0ad0f4976c77 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; -final class UnsafeExternalRowSorter { +public final class UnsafeExternalRowSorter { /** * If positive, forces records to be spilled to disk at the given frequency (measured in numbers @@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) { testSpillFrequency = frequency; } - @VisibleForTesting - void insertRow(UnsafeRow row) throws IOException { + public void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), @@ -110,8 +109,7 @@ private void cleanupResources() { sorter.cleanupResources(); } - @VisibleForTesting - Iterator<UnsafeRow> sort() throws IOException { + public Iterator<UnsafeRow> sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -160,7 +158,6 @@ public UnsafeRow next() { } } - public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> 
inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 75cb6d1137c3..2ea889ea72c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -37,7 +39,7 @@ case class Sort( global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) - extends UnaryNode { + extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -50,34 +52,36 @@ case class Sort( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output + def createSorter(): UnsafeExternalRowSorter = { + val ordering = newOrdering(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter + } + + protected override def doExecute(): RDD[InternalRow] = { val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") child.execute().mapPartitionsInternal { iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } + val sorter = createSorter() val metrics = TaskContext.get().taskMetrics() // Remember spill data size of 
this task before execute this operator so that we can @@ -93,4 +97,74 @@ case class Sort( sortedIterator } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + // Name of sorter variable used in codegen. + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.freshName("needToSort") + ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + sorterVariable = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + s"$sorterVariable = $thisPlan.createSorter();") + val metrics = ctx.freshName("metrics") + ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") + val sortedIterator = ctx.freshName("sortedIter") + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + + val addToSorter = ctx.freshName("addToSorter") + ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val dataSize = metricTerm(ctx, "dataSize") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + s""" + | if ($needToSort) { + | $addToSorter(); + | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $sortedIterator = $sorterVariable.sort(); + | $dataSize.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val colExprs = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs) + + s""" + | // Convert the input attributes to an UnsafeRow and add it to the sorter + | ${code.code} + | $sorterVariable.insertRow(${code.value}); + """.stripMargin.trim + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index afaddcf35775..cb68ca6ada36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ${code.trim} } } - """ + """.trim // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) @@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // There is an UnsafeRow already s""" |append($row.copy()); - """.stripMargin + """.stripMargin.trim } else { assert(input != null) if (input.nonEmpty) { @@ -351,12 +351,12 @@ case class 
WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) s""" |${code.code.trim} |append(${code.value}.copy()); - """.stripMargin + """.stripMargin.trim } else { // There is no columns s""" |append(unsafeRow); - """.stripMargin + """.stripMargin.trim } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9350205d791d..de371d85d9fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } + + test("Sort should be included in WholeStageCodegen") { + val df = sqlContext.range(3, 0, -1).sort(col("id")) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index c49f2439fce4..5b4f6f1d2461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("Sort metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) + val df = sqlContext.range(10).sort('id) + testSparkPlanMetrics(df, 2, Map.empty) + } + test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. From c7fccb56cd9260b8d72572e65f8e46a14707b9a5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Feb 2016 13:01:27 -0800 Subject: [PATCH 13/28] [SPARK-13478][YARN] Use real user when fetching delegation tokens. The Hive client library is not smart enough to notice that the current user is a proxy user; so when using a proxy user, it fails to fetch delegation tokens from the metastore because of a missing kerberos TGT for the current user. To fix it, just run the code that fetches the delegation token as the real logged in user. Tested on a kerberos cluster both submitting normally and with a proxy user; Hive and HBase tokens are retrieved correctly in both cases. Author: Marcelo Vanzin Closes #11358 from vanzin/SPARK-13478. 
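The core of the change below is to wrap the reflective Hive calls in a doAs on the real logged-in user rather than the proxy user, since only the real user holds a kerberos TGT. A condensed, illustrative sketch of that pattern with Hadoop's UserGroupInformation API follows (the actual implementation is the doAsRealUser helper added to YarnSparkHadoopUtil in the diff):

import java.security.PrivilegedExceptionAction

import org.apache.hadoop.security.UserGroupInformation

// Runs `body` as the real logged-in user; for a proxy user, getRealUser()
// returns the underlying user that actually holds the kerberos credentials.
def runAsRealUser[T](body: => T): T = {
  val current = UserGroupInformation.getCurrentUser()
  val real = Option(current.getRealUser()).getOrElse(current)
  real.doAs(new PrivilegedExceptionAction[T] {
    override def run(): T = body
  })
}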
--- .../spark/deploy/SparkSubmitArguments.scala | 5 ++ .../deploy/yarn/YarnSparkHadoopUtil.scala | 46 ++++++++++++++----- .../yarn/YarnSparkHadoopUtilSuite.scala | 2 +- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 915ef81b4eae..175756b80b6b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -255,6 +255,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } + + if (proxyUser != null && principal != null) { + SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + } } private def validateKillArguments(): Unit = { @@ -517,6 +521,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | | --proxy-user NAME User to impersonate when submitting the application. + | This argument does not work with --principal / --keytab. | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4c9432dbd6ab..aef78fdfd4c5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,7 +18,9 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.lang.reflect.UndeclaredThrowableException import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction import java.util.regex.Matcher import java.util.regex.Pattern @@ -194,7 +196,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { */ def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { try { - obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + obtainTokenForHiveMetastoreInner(conf) } catch { case e: ClassNotFoundException => logInfo(s"Hive class not found $e") @@ -209,8 +211,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { * @param username the username of the principal requesting the delegating token. 
* @return a delegation token */ - private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, - username: String): Option[Token[DelegationTokenIdentifier]] = { + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration): + Option[Token[DelegationTokenIdentifier]] = { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down @@ -225,11 +227,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // Check for local metastore if (metastoreUri.nonEmpty) { - require(username.nonEmpty, "Username undefined") val principalKey = "hive.metastore.kerberos.principal" val principal = hiveConf.getTrimmed(principalKey, "") require(principal.nonEmpty, "Hive principal $principalKey undefined") - logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") val closeCurrent = hiveClass.getMethod("closeCurrent") try { @@ -238,12 +241,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { classOf[String], classOf[String]) val getHive = hiveClass.getMethod("get", hiveConfClass) - // invoke - val hive = getHive.invoke(null, hiveConf) - val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - Some(hive2Token) + doAsRealUser { + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) + .asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } } finally { Utils.tryLogNonFatalError { closeCurrent.invoke(null) @@ -303,6 +308,25 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } } + /** + * Run some code as the real logged in user (which may differ from the current user, for + * example, when using proxying). + */ + private def doAsRealUser[T](fn: => T): T = { + val currentUser = UserGroupInformation.getCurrentUser() + val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) + + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. 
+ try { + realUser.doAs(new PrivilegedExceptionAction[T]() { + override def run(): T = fn + }) + } catch { + case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e) + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index d3acaf229cc8..9202bd892f01 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -255,7 +255,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging hadoopConf.set("hive.metastore.uris", "http://localhost:0") val util = new YarnSparkHadoopUtil assertNestedHiveException(intercept[InvocationTargetException] { - util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + util.obtainTokenForHiveMetastoreInner(hadoopConf) }) assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) From 0a4b620f3144d68232eb7914ae05563aab648ced Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 22:24:43 -0800 Subject: [PATCH 14/28] [SPARK-13551][MLLIB] Fix wrong comment and remove meaningless lines in mllib.JavaBisectingKMeansExample JIRA: https://issues.apache.org/jira/browse/SPARK-13551 ## What changes were proposed in this pull request? Fix wrong comment and remove meaningless lines in mllib.JavaBisectingKMeansExample ## How was this patch tested? manual test Author: Zheng RuiFeng Closes #11429 from zhengruifeng/mllib_bkm_je. --- .../spark/examples/mllib/JavaBisectingKMeansExample.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java index 0001500f4fa5..c600094947d5 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -33,7 +33,7 @@ // $example off$ /** - * Java example for graph clustering using power iteration clustering (PIC). + * Java example for bisecting k-means clustering. */ public class JavaBisectingKMeansExample { public static void main(String[] args) { @@ -54,9 +54,7 @@ public static void main(String[] args) { BisectingKMeansModel model = bkm.run(data); System.out.println("Compute Cost: " + model.computeCost(data)); - for (Vector center: model.clusterCenters()) { - System.out.println(""); - } + Vector[] clusterCenters = model.clusterCenters(); for (int i = 0; i < clusterCenters.length; i++) { Vector clusterCenter = clusterCenters[i]; From 3c5f5e3b5c4eb69472fdd8124aa9988bd8d933b5 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 29 Feb 2016 23:55:26 -0800 Subject: [PATCH 15/28] [SPARK-13550][ML] Add java example for ml.clustering.BisectingKMeans JIRA: https://issues.apache.org/jira/browse/SPARK-13550 ## What changes were proposed in this pull request? Just add a java example for ml.clustering.BisectingKMeans ## How was this patch tested? manual tests were done. Author: Zheng RuiFeng Closes #11428 from zhengruifeng/ml_bkm_je.
--- .../ml/JavaBisectingKMeansExample.java | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java new file mode 100644 index 000000000000..e124c1cf1855 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans; +import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + + +/** + * An example demonstrating a bisecting k-means clustering. 
+ */ +public class JavaBisectingKMeansExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), + RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), + RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), + RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), + RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), + RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + BisectingKMeans bkm = new BisectingKMeans().setK(2); + BisectingKMeansModel model = bkm.fit(dataset); + + System.out.println("Compute Cost: " + model.computeCost(dataset)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + jsc.stop(); + } +} From 12a2a57e1af21da0aa4275971365d76a8fc84a43 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 1 Mar 2016 14:37:36 +0000 Subject: [PATCH 16/28] [SPARK-13592][WINDOWS] fix path of spark-submit2.cmd in spark-submit.cmd ## What changes were proposed in this pull request? This patch fixes the problem that pyspark fails on Windows because pyspark can't find ```spark-submit2.cmd```. ## How was this patch tested? manual tests: I ran ```bin\pyspark.cmd``` and checked if pyspark is launched correctly after this patch is applyed. Author: Masayoshi TSUZUKI Closes #11442 from tsudukim/feature/SPARK-13592. --- bin/spark-submit.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f121b62a53d2..f301606933a9 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C spark-submit2.cmd %* +cmd /V /E /C "%~dp0spark-submit2.cmd" %* From c43899a04e4de18e238a1761bf4fe9f54e182320 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Mar 2016 08:43:02 -0800 Subject: [PATCH 17/28] [SPARK-13511] [SQL] Add wholestage codegen for limit JIRA: https://issues.apache.org/jira/browse/SPARK-13511 ## What changes were proposed in this pull request? Current limit operator doesn't support wholestage codegen. This is open to add support for it. In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row numbers catches the limit number, we set the variable `stopEarly` of `BufferedRowIterator` newly added in this pr to `true` that indicates we want to stop processing remaining rows. Then when the wholestage codegen framework checks `shouldStop()`, it will stop the processing of the row iterator. 
Before this, the executed plan for a query `sqlContext.range(N).limit(100).groupBy().sum()` is: TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L]) +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L]) +- GlobalLimit 100 +- Exchange SinglePartition, None +- LocalLimit 100 +- Range 0, 1, 1, 524288000, [id#5L] After add wholestage codegen support: WholeStageCodegen : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L]) : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L]) : +- GlobalLimit 100 : +- INPUT +- Exchange SinglePartition, None +- WholeStageCodegen : +- LocalLimit 100 : +- Range 0, 1, 1, 524288000, [id#40L] ## How was this patch tested? A test is added into BenchmarkWholeStageCodegen. Author: Liang-Chi Hsieh Closes #11391 from viirya/wholestage-limit. --- .../apache/spark/sql/execution/limit.scala | 35 +++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 14 ++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d419528..45175d36d5c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. 
*/ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 6d6cc0186a96..2d3e34d0e129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/limit/sum") { + val N = 500 << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + ignore("stat functions") { val N = 100 << 20 From 5ed48dd84d38dfe621428e164a02e74ddbdbc622 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 1 Mar 2016 08:47:56 -0800 Subject: [PATCH 18/28] [SPARK-12811][ML] Estimator for Generalized Linear Models(GLMs) Estimator for Generalized Linear Models(GLMs) which will be solved by IRLS. cc mengxr Author: Yanbo Liang Closes #11136 from yanboliang/spark-12811. 
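Since the commit message above is terse, a minimal usage sketch may be helpful; it is hypothetical (the `training` DataFrame and its "label"/"features" column layout are assumed) but uses only the setters defined on the new estimator in the diff below.

import org.apache.spark.ml.regression.GeneralizedLinearRegression

// `training` is assumed to be a DataFrame with the usual "label" and "features" columns.
val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")   // gaussian, binomial, poisson or gamma
  .setLink("log")         // must be a valid link for the chosen family
  .setMaxIter(25)         // IRLS iterations
  .setRegParam(0.0)
val model = glr.fit(training)
println(s"coefficients: ${model.coefficients} intercept: ${model.intercept}")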
--- .../spark/ml/optim/WeightedLeastSquares.scala | 10 +- .../GeneralizedLinearRegression.scala | 577 ++++++++++++++++++ .../ml/regression/LinearRegression.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 507 +++++++++++++++ 4 files changed, 1094 insertions(+), 4 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 61b364213181..55b751065664 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { + /** + * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]] + * only supports the number of features is no more than 4096. + */ + val MAX_NUM_FEATURES: Int = 4096 + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares { private var aaSum: DenseVector = _ private def init(k: Int): Unit = { - require(k <= 4096, "In order to take the normal equation approach efficiently, " + - s"we set the max number of features to 4096 but got $k.") + require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to $MAX_NUM_FEATURES but got $k.") this.k = k triK = k * (k + 1) / 2 count = 0L diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala new file mode 100644 index 000000000000..a850dfee0a45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import breeze.stats.distributions.{Gaussian => GD} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ + +/** + * Params for Generalized Linear Regression. + */ +private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams + with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol + with HasSolver with Logging { + + /** + * Param for the name of family which is a description of the error distribution + * to be used in the model. + * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Default is "gaussian". + * @group param + */ + @Since("2.0.0") + final val family: Param[String] = new Param(this, "family", + "The name of family which is a description of the error distribution to be used in the " + + "model. Supported options: gaussian(default), binomial, poisson and gamma.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getFamily: String = $(family) + + /** + * Param for the name of link function which provides the relationship + * between the linear predictor and the mean of the distribution function. + * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * @group param + */ + @Since("2.0.0") + final val link: Param[String] = new Param(this, "link", "The name of link function " + + "which provides the relationship between the linear predictor and the mean of the " + + "distribution function. Supported options: identity, log, inverse, logit, probit, " + + "cloglog and sqrt.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getLink: String = $(link) + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + override def validateParams(): Unit = { + if ($(solver) == "irls") { + setDefault(maxIter -> 25) + } + if (isDefined(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + + s"with ${$(family)} family does not support ${$(link)} link function.") + } + } +} + +/** + * :: Experimental :: + * + * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]]) + * specified by giving a symbolic description of the linear predictor (link function) and + * a description of the error distribution (family). + * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * Valid link functions for each family is listed below. The first link function of each family + * is the default one. 
+ * - "gaussian" -> "identity", "log", "inverse" + * - "binomial" -> "logit", "probit", "cloglog" + * - "poisson" -> "log", "identity", "sqrt" + * - "gamma" -> "inverse", "identity", "log" + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with Logging { + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("glm")) + + /** + * Sets the value of param [[family]]. + * Default is "gaussian". + * @group setParam + */ + @Since("2.0.0") + def setFamily(value: String): this.type = set(family, value) + setDefault(family -> Gaussian.name) + + /** + * Sets the value of param [[link]]. + * @group setParam + */ + @Since("2.0.0") + def setLink(value: String): this.type = set(link, value) + + /** + * Sets if we should fit the intercept. + * Default is true. + * @group setParam + */ + @Since("2.0.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** + * Sets the maximum number of iterations. + * Default is 25 if the solver algorithm is "irls". + * @group setParam + */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Sets the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Sets the regularization parameter. + * Default is 0.0. + * @group setParam + */ + @Since("2.0.0") + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("2.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Sets the solver algorithm used for optimization. + * Currently only support "irls" which is also the default solver. + * @group setParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "irls") + + override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + val familyObj = Family.fromName($(family)) + val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd + .map { case Row(features: Vector) => + features.size + }.first() + if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { + val msg = "Currently, GeneralizedLinearRegression only supports number of features" + + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." + throw new SparkException(msg) + } + + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd + .map { case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + if (familyObj == Gaussian && linkObj == Identity) { + // TODO: Make standardizeFeatures and standardizeLabel configurable. 
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + standardizeFeatures = true, standardizeLabel = true) + val wlsModel = optimizer.fit(instances) + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) + .setParent(this)) + return model + } + + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, + $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) + model + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra) +} + +@Since("2.0.0") +private[ml] object GeneralizedLinearRegression { + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyAndLinkPairs = Set( + Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, + Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, + Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, + Gamma -> Inverse, Gamma -> Identity, Gamma -> Log + ) + + /** Set of family names that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + + /** Set of link names that GeneralizedLinearRegression supports. */ + lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + + val epsilon: Double = 1E-16 + + /** + * Wrapper of family and link combination used in the model. + */ + private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + + /** Linear predictor based on given mu. */ + def predict(mu: Double): Double = link.link(family.project(mu)) + + /** Fitted value based on linear predictor eta. */ + def fitted(eta: Double): Double = family.project(link.unlink(eta)) + + /** + * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. + */ + def initialize( + instances: RDD[Instance], + fitIntercept: Boolean, + regParam: Double): WeightedLeastSquaresModel = { + val newInstances = instances.map { instance => + val mu = family.initialize(instance.label, instance.weight) + val eta = predict(mu) + Instance(eta, instance.weight, instance.features) + } + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + standardizeFeatures = true, standardizeLabel = true) + .fit(newInstances) + initialModel + } + + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + } + + /** + * A description of the error distribution to be used in the model. + * @param name the name of the family. + */ + private[ml] abstract class Family(val name: String) extends Serializable { + + /** The default link instance of this family. 
*/ + val defaultLink: Link + + /** Initialize the starting value for mu. */ + def initialize(y: Double, weight: Double): Double + + /** The variance of the endogenous variable's mean, given the value mu. */ + def variance(mu: Double): Double + + /** Trim the fitted value so that it will be in valid range. */ + def project(mu: Double): Double = mu + } + + private[ml] object Family { + + /** + * Gets the [[Family]] object from its name. + * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + */ + def fromName(name: String): Family = { + name match { + case Gaussian.name => Gaussian + case Binomial.name => Binomial + case Poisson.name => Poisson + case Gamma.name => Gamma + } + } + } + + /** + * Gaussian exponential family distribution. + * The default link for the Gaussian family is the identity link. + */ + private[ml] object Gaussian extends Family("gaussian") { + + val defaultLink: Link = Identity + + override def initialize(y: Double, weight: Double): Double = y + + def variance(mu: Double): Double = 1.0 + + override def project(mu: Double): Double = { + if (mu.isNegInfinity) { + Double.MinValue + } else if (mu.isPosInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Binomial exponential family distribution. + * The default link for the Binomial family is the logit link. + */ + private[ml] object Binomial extends Family("binomial") { + + val defaultLink: Link = Logit + + override def initialize(y: Double, weight: Double): Double = { + val mu = (weight * y + 0.5) / (weight + 1.0) + require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" + + s"should be in range (0, 1), but got $mu") + mu + } + + override def variance(mu: Double): Double = mu * (1.0 - mu) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu > 1.0 - epsilon) { + 1.0 - epsilon + } else { + mu + } + } + } + + /** + * Poisson exponential family distribution. + * The default link for the Poisson family is the log link. + */ + private[ml] object Poisson extends Family("poisson") { + + val defaultLink: Link = Log + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Poisson family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Gamma exponential family distribution. + * The default link for the Gamma family is the inverse link. + */ + private[ml] object Gamma extends Family("gamma") { + + val defaultLink: Link = Inverse + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Gamma family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = math.pow(mu, 2.0) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * A description of the link function to be used in the model. + * The link function provides the relationship between the linear predictor + * and the mean of the distribution function. + * @param name the name of link function. + */ + private[ml] abstract class Link(val name: String) extends Serializable { + + /** The link function. */ + def link(mu: Double): Double + + /** Derivative of the link function. 
*/ + def deriv(mu: Double): Double + + /** The inverse link function. */ + def unlink(eta: Double): Double + } + + private[ml] object Link { + + /** + * Gets the [[Link]] object from its name. + * @param name link name: "identity", "logit", "log", + * "inverse", "probit", "cloglog" or "sqrt". + */ + def fromName(name: String): Link = { + name match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } + } + } + + private[ml] object Identity extends Link("identity") { + + override def link(mu: Double): Double = mu + + override def deriv(mu: Double): Double = 1.0 + + override def unlink(eta: Double): Double = eta + } + + private[ml] object Logit extends Link("logit") { + + override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) + + override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) + } + + private[ml] object Log extends Link("log") { + + override def link(mu: Double): Double = math.log(mu) + + override def deriv(mu: Double): Double = 1.0 / mu + + override def unlink(eta: Double): Double = math.exp(eta) + } + + private[ml] object Inverse extends Link("inverse") { + + override def link(mu: Double): Double = 1.0 / mu + + override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0) + + override def unlink(eta: Double): Double = 1.0 / eta + } + + private[ml] object Probit extends Link("probit") { + + override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu) + + override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu)) + + override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta) + } + + private[ml] object CLogLog extends Link("cloglog") { + + override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) + + override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) + } + + private[ml] object Sqrt extends Link("sqrt") { + + override def link(mu: Double): Double = math.sqrt(mu) + + override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) + + override def unlink(eta: Double): Double = math.pow(eta, 2.0) + } +} + +/** + * :: Experimental :: + * Model produced by [[GeneralizedLinearRegression]]. 
+ */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionModel private[ml] ( + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("2.0.0") val intercept: Double) + extends RegressionModel[Vector, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase { + + import GeneralizedLinearRegression._ + + lazy val familyObj = Family.fromName($(family)) + lazy val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + override protected def predict(features: Vector): Double = { + val eta = BLAS.dot(features, coefficients) + intercept + familyAndLink.fitted(eta) + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { + copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) + .setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8f78fd122f34..b4f17b8e2898 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || - $(solver) == "normal") { + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala new file mode 100644 index 000000000000..8bfa9855ce4e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.random._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val seed: Int = 42 + @transient var datasetGaussianIdentity: DataFrame = _ + @transient var datasetGaussianLog: DataFrame = _ + @transient var datasetGaussianInverse: DataFrame = _ + @transient var datasetBinomial: DataFrame = _ + @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonIdentity: DataFrame = _ + @transient var datasetPoissonSqrt: DataFrame = _ + @transient var datasetGammaInverse: DataFrame = _ + @transient var datasetGammaIdentity: DataFrame = _ + @transient var datasetGammaLog: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + import GeneralizedLinearRegressionSuite._ + + datasetGaussianIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity"), 2)) + + datasetGaussianLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log"), 2)) + + datasetGaussianInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse"), 2)) + + datasetBinomial = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, seed) + + sqlContext.createDataFrame(sc.parallelize(testData, 2)) + } + + datasetPoissonLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log"), 2)) + + datasetPoissonIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity"), 2)) + + datasetPoissonSqrt = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, 
noiseLevel = 0.01, + family = "poisson", link = "sqrt"), 2)) + + datasetGammaInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse"), 2)) + + datasetGammaIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity"), 2)) + + datasetGammaLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log"), 2)) + } + + test("params") { + ParamsSuite.checkParams(new GeneralizedLinearRegression) + val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("generalized linear regression: default params") { + val glr = new GeneralizedLinearRegression + assert(glr.getLabelCol === "label") + assert(glr.getFeaturesCol === "features") + assert(glr.getPredictionCol === "prediction") + assert(glr.getFitIntercept) + assert(glr.getTol === 1E-6) + assert(glr.getWeightCol === "") + assert(glr.getRegParam === 0.0) + assert(glr.getSolver == "irls") + // TODO: Construct model directly instead of via fitting. + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.getFamily === "gaussian") + assert(model.getLink === "identity") + } + + test("generalized linear regression: gaussian family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="gaussian", data=data) + print(as.vector(coef(model))) + } + + [1] 2.2960999 0.8087933 + [1] 2.5002642 2.2000403 0.5999485 + + data <- read.csv("path", header=FALSE) + model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0)) + model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0)) + print(as.vector(coef(model1))) + print(as.vector(coef(model2))) + + [1] 0.23069326 0.07993778 + [1] 0.25001858 0.22002452 0.05998789 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=gaussian(link=inverse), data=data) + print(as.vector(coef(model))) + } + + [1] 2.3010179 0.8198976 + [1] 2.4108902 2.2130248 0.6086152 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2960999, 0.8087933), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(0.0, 0.23069326, 0.07993778), + Vectors.dense(0.25001858, 0.22002452, 0.05998789), + Vectors.dense(0.0, 2.3010179, 0.8198976), + Vectors.dense(2.4108902, 2.2130248, 0.6086152)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog), + ("inverse", datasetGaussianInverse))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gaussian family against glmnet") { + /* + R code: + library(glmnet) + data <- read.csv("path", header=FALSE) + label = data$V1 + features = as.matrix(data.frame(data$V2, data$V3)) + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + model <- glmnet(features, label, family="gaussian", intercept=intercept, + lambda=lambda, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + + [1] 0.0000000 2.2961005 0.8087932 + [1] 0.0000000 2.2130368 0.8309556 + [1] 0.0000000 1.7176137 0.9610657 + [1] 2.5002642 2.2000403 0.5999485 + [1] 3.1106389 2.0935142 0.5712711 + [1] 6.7597127 1.4581054 0.3994266 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2961005, 0.8087932), + Vectors.dense(0.0, 2.2130368, 0.8309556), + Vectors.dense(0.0, 1.7176137, 0.9610657), + Vectors.dense(2.5002642, 2.2000403, 
0.5999485), + Vectors.dense(3.1106389, 2.0935142, 0.5712711), + Vectors.dense(6.7597127, 1.4581054, 0.3994266)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian") + .setFitIntercept(fitIntercept).setRegParam(regParam) + val model = trainer.fit(datasetGaussianIdentity) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"fitIntercept = $fitIntercept and regParam = $regParam.") + + idx += 1 + } + } + + test("generalized linear regression: binomial family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 + data <- read.csv("path", header=FALSE) + + for (formula in c(f1, f2)) { + model <- glm(formula, family="binomial", data=data) + print(as.vector(coef(model))) + } + + [1] -0.3560284 1.3010002 -0.3570805 -0.7406762 + [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=probit), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2134390 0.7800646 -0.2144267 -0.4438358 + [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=cloglog), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2832198 0.8434144 -0.2524727 -0.5293452 + [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758 + */ + val expected = Seq( + Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762), + Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989), + Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358), + Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850), + Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452), + Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial), + ("cloglog", datasetBinomial))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), + model.coefficients(2), model.coefficients(3)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: poisson family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + 
print(as.vector(coef(model))) + } + + [1] 0.22999393 0.08047088 + [1] 0.25022353 0.21998599 0.05998621 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2929501 0.8119415 + [1] 2.5012730 2.1999407 0.5999107 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=sqrt), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2958947 0.8090515 + [1] 2.5000480 2.1999972 0.5999968 + */ + val expected = Seq( + Vectors.dense(0.0, 0.22999393, 0.08047088), + Vectors.dense(0.25022353, 0.21998599, 0.05998621), + Vectors.dense(0.0, 2.2929501, 0.8119415), + Vectors.dense(2.5012730, 2.1999407, 0.5999107), + Vectors.dense(0.0, 2.2958947, 0.8090515), + Vectors.dense(2.5000480, 2.1999972, 0.5999968)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity), + ("sqrt", datasetPoissonSqrt))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gamma family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="Gamma", data=data) + print(as.vector(coef(model))) + } + + [1] 2.3392419 0.8058058 + [1] 2.3507700 2.2533574 0.6042991 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2908883 0.8147796 + [1] 2.5002406 2.1998346 0.6000059 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=log), data=data) + print(as.vector(coef(model))) + } + + [1] 0.22958970 0.08091066 + [1] 0.25003210 0.21996957 0.06000215 + */ + val expected = Seq( + Vectors.dense(0.0, 2.3392419, 0.8058058), + Vectors.dense(2.3507700, 2.2533574, 0.6042991), + Vectors.dense(0.0, 2.2908883, 0.8147796), + Vectors.dense(2.5002406, 2.1998346, 0.6000059), + Vectors.dense(0.0, 0.22958970, 0.08091066), + Vectors.dense(0.25003210, 0.21996957, 0.06000215)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), + ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + .setFitIntercept(fitIntercept) + val model = 
trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } +} + +object GeneralizedLinearRegressionSuite { + + def generateGeneralizedLinearRegressionInput( + intercept: Double, + coefficients: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + noiseLevel: Double, + family: String, + link: String): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + def rndElement(i: Int) = { + (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } + val (generator, mean) = family match { + case "gaussian" => (new StandardNormalGenerator, 0.0) + case "poisson" => (new PoissonGenerator(1.0), 1.0) + case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0) + } + generator.setSeed(seed) + + (0 until nPoints).map { _ => + val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept + val mu = link match { + case "identity" => eta + case "log" => math.exp(eta) + case "sqrt" => math.pow(eta, 2.0) + case "inverse" => 1.0 / eta + } + val label = mu + noiseLevel * (generator.nextValue() - mean) + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } +} From c37bbb3a1cbd93c749aaaeca1345817e0c20094f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 1 Mar 2016 09:40:33 -0800 Subject: [PATCH 19/28] Closes #11320 Closes #10940 Closes #11302 Closes #11430 Closes #10912 From c27ba0d547a0cd3fd00bb42c76ad971b2d48b4a0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Mar 2016 13:07:04 -0800 Subject: [PATCH 20/28] [SPARK-13582] [SQL] defer dictionary decoding in parquet reader ## What changes were proposed in this pull request? This PR defers the resolution of a dictionary id to its value until the column is actually accessed (inside getInt/getLong), which is very useful for columns and rows that are filtered out. It is also useful for the binary type, since we no longer need to copy all the byte arrays. This PR also changes the underlying type for small decimals that fit within an Int, so that getInt() can be used to look up the value from an IntDictionary. ## How was this patch tested? Manually tested TPCDS Q7 with scale factor 10 and saw about a 30% improvement (after PR #11274). Author: Davies Liu Closes #11437 from davies/decode_dict.
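To illustrate the idea, here is a minimal sketch of deferred dictionary decoding (the class and member names are hypothetical and do not match Spark's actual ColumnVector API): the id-to-value lookup is paid only for rows that are actually read, so rows skipped by a filter never touch the dictionary.

```scala
// Illustrative sketch only; names are hypothetical, not Spark's API.
// dictionaryIds holds one small id per row (read eagerly from the page);
// dictionary is the id -> value table, decoded once per page.
class LazyDictionaryColumn(dictionaryIds: Array[Int], dictionary: Array[Long]) {
  // The id-to-value resolution happens here, on access, so rows that a
  // filter discards never pay the decoding (or byte-array copying) cost.
  def getLong(rowId: Int): Long = dictionary(dictionaryIds(rowId))
}
```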
--- .../org/apache/spark/sql/types/Decimal.scala | 3 + .../apache/spark/sql/types/DecimalType.scala | 11 ++ .../parquet/UnsafeRowParquetRecordReader.java | 101 +++++------------ .../parquet/VectorizedRleValuesReader.java | 33 ------ .../execution/vectorized/ColumnVector.java | 105 +++++++++++++++--- .../vectorized/ColumnVectorUtils.java | 8 +- .../execution/vectorized/ColumnarBatch.java | 24 +--- .../vectorized/OffHeapColumnVector.java | 58 ++++++---- .../vectorized/OnHeapColumnVector.java | 47 ++++++-- .../parquet/CatalystRowConverter.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 14 +-- .../parquet/CatalystWriteSupport.scala | 8 +- .../parquet/ParquetEncodingSuite.scala | 2 +- .../vectorized/ColumnarBatchBenchmark.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 15 files changed, 221 insertions(+), 203 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ed..6a59e9728a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits an Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 2e03ddae760b..9c1319c1c5e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside an int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + /** * Returns if dt is a DecimalType that fits inside a long */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index e7f0ec2e7789..57dbd7c2ff56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -257,8 +257,7 @@ private void initializeInternal() throws IOException { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() > - CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) { throw new IOException("Decimal with high precision is not supported."); } if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { @@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); int precision = type.getDecimalMetadata().getPrecision(); int scale = type.getDecimalMetadata().getScale(); -
Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(), "Unsupported precision."); for (int n = 0; n < num; ++n) { @@ -480,11 +479,6 @@ private final class ColumnReader { */ private boolean useDictionary; - /** - * If useDictionary is true, the staging vector used to decode the ids. - */ - private ColumnVector dictionaryIds; - /** * Maximum definition level for this column. */ @@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException { } int num = Math.min(total, leftInPage); if (useDictionary) { - // Data is dictionary encoded. We will vector decode the ids and then resolve the values. - if (dictionaryIds == null) { - dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(total); - } // Read and decode dictionary ids. + ColumnVector dictionaryIds = column.reserveDictionaryIds(total); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column); + decodeDictionaryIds(rowId, num, column, dictionaryIds); } else { + column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: - if (column.dataType() == DataTypes.IntegerType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ByteType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ShortType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case FLOAT: - for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); - } - break; - case DOUBLE: - for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); - } + case INT64: + case FLOAT: + case DOUBLE: + case BINARY: + column.setDictionary(dictionary); break; case FIXED_LEN_BYTE_ARRAY: - if (DecimalType.is64BitDecimalType(column.dataType())) { + // DecimalType written in the legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary
v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); @@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } break; - case BINARY: - // TODO: this is incredibly inefficient as it blows up the dictionary right here. We - // need to do this better. We should probably add the dictionary data to the ColumnVector - // and reuse it across batches. This should mean adding a ByteArray would just update - // the length and offset. - for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); - } - break; - default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } @@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - defColumn.readIntsAsLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ShortType) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num, VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions - if (DecimalType.is64BitDecimalType(column.dataType())) { + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 8613fcae0b80..62157389013b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c, } } - public void readIntsAsLongs(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { - int left = total; - while (left > 0) { - if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - switch (mode) { - case RLE: - if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putLong(rowId + i, data.readInteger()); - } - } else { - c.putNulls(rowId, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == level) { - c.putLong(rowId + i, data.readInteger()); - } else { - c.putNull(rowId + i); - } - } - break; - } - rowId += n; - left -= n; - currentCount -= n; - } - } - public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0514252a8e53..bb0247c2fbed 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -27,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. 
@@ -157,7 +159,7 @@ public Object[] array() { } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { - list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + list[i] = getUTF8String(i).toString(); } } } else if (dt instanceof CalendarIntervalType) { @@ -204,28 +206,17 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return data.getDecimal(offset + ordinal, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - Array child = data.getByteArray(offset + ordinal); - return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + return data.getUTF8String(offset + ordinal); } @Override public byte[] getBinary(int ordinal) { - ColumnVector.Array array = data.getByteArray(offset + ordinal); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return data.getBinary(offset + ordinal); } @Override @@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - public final Array getByteArray(int rowId) { + private Array getByteArray(int rowId) { Array array = getArray(rowId); array.data.loadBytes(array); return array; } + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.apply(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. + */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) { */ protected final ColumnarBatch.Row resultStruct; + /** + * The Dictionary for this column. + * + * If it's not null, will be used to decode the value in getXXX(). 
+ */ + protected Dictionary dictionary; + + /** + * Reusable column for dictionary ids. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve an integer column for dictionary ids. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2aeef7f2f90f..681ace338713 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -22,24 +22,20 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.commons.lang.NotImplementedException; - /** * Utilities to help manipulate data associated with ColumnVectors. These should be used mostly * for debugging or other non-performance critical paths. * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { - public static String toString(ColumnVector.Array a) { - return new String(a.byteArray, a.byteArrayOffset, a.length); - } - /** * Returns the array data as the java primitive array. * For example, an array of IntegerType will return an int[]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 070d897a7158..8a0d7f8b1237 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -31,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class is the in memory representation of rows as they are streamed through operators. It * is designed to maximize CPU efficiency and not storage footprint.
Since it is expected that @@ -193,29 +191,17 @@ public final boolean anyNull() { @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = columns[ordinal].getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + return columns[ordinal].getUTF8String(rowId); } @Override public final byte[] getBinary(int ordinal) { - ColumnVector.Array array = columns[ordinal].getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return columns[ordinal].getBinary(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e38ed051219b..b06b7f2457b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,25 +18,11 @@ import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - - import org.apache.commons.lang.NotImplementedException; -import org.apache.commons.lang.NotImplementedException; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. 
@@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return Platform.getByte(null, data + rowId); + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return Platform.getShort(null, data + 2 * rowId); + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return Platform.getInt(null, data + 4 * rowId); + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return Platform.getLong(null, data + 8 * rowId); + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public final float getFloat(int rowId) { - return Platform.getFloat(null, data + rowId * 4); + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } } @@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return Platform.getDouble(null, data + rowId * 8); + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) { } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || - type instanceof DateType) { + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3502d31bd1df..305e84a86bdc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,12 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.util.Arrays; + import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.util.Arrays; - /** * A column backed by an in memory JVM array. 
This stores the NULLs as a byte per value * and a java array for the values. @@ -68,7 +67,6 @@ public final void close() { doubleData = null; } - // // APIs dealing with nulls // @@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return byteData[rowId]; + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return shortData[rowId]; + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } @@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return intData[rowId]; + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return longData[rowId]; + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { } @Override - public final float getFloat(int rowId) { return floatData[rowId]; } + public final float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } // // APIs dealing with doubles @@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return doubleData[rowId]; + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType || type instanceof DateType) { + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 42d89f4bf81d..8a128b4b6176 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. 
val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index ab4250d0adba..6f6340f541ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) case UINT_8 => typeNotSupported() case UINT_16 => typeNotSupported() case UINT_32 => typeNotSupported() @@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() @@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 3508220c9541..0252c79d8e14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -33,7 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index cef6b79a094d..281a2cffa894 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(batch.column(0).getByte(i) == 1) assert(batch.column(1).getInt(i) == 2) assert(batch.column(2).getLong(i) == 3) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + assert(batch.column(3).getUTF8String(i).toString == "abc") i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8efdf8adb042..97638a66ab47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -370,7 +370,7 @@ object ColumnarBatchBenchmark { } i = 0 while (i < count) { - sum += column.getByteArray(i).length + sum += column.getUTF8String(i).numBytes() i += 1 } column.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 445f311107e3..b3c3e66fbcbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) From b0ee7d43730469ad61fdf6b7b75cc1b1efb62c31 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 1 Mar 2016 15:39:13 -0800 Subject: [PATCH 21/28] [SPARK-13548][BUILD] Move tags and unsafe modules into common ## What changes were proposed in this pull request? This patch moves the tags and unsafe modules into the common directory to remove two top-level non-user-facing directories. ## How was this patch tested? Jenkins should suffice. Author: Reynold Xin Closes #11426 from rxin/SPARK-13548.
--- {tags => common/tags}/README.md | 0 {tags => common/tags}/pom.xml | 2 +- .../tags}/src/main/java/org/apache/spark/tags/DockerTest.java | 0 .../src/main/java/org/apache/spark/tags/ExtendedHiveTest.java | 0 .../src/main/java/org/apache/spark/tags/ExtendedYarnTest.java | 0 {unsafe => common/unsafe}/pom.xml | 2 +- .../src/main/java/org/apache/spark/unsafe/KVIterator.java | 0 .../src/main/java/org/apache/spark/unsafe/Platform.java | 0 .../java/org/apache/spark/unsafe/array/ByteArrayMethods.java | 0 .../main/java/org/apache/spark/unsafe/array/LongArray.java | 0 .../java/org/apache/spark/unsafe/bitset/BitSetMethods.java | 0 .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 0 .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 0 .../java/org/apache/spark/unsafe/memory/MemoryAllocator.java | 0 .../main/java/org/apache/spark/unsafe/memory/MemoryBlock.java | 0 .../java/org/apache/spark/unsafe/memory/MemoryLocation.java | 0 .../org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java | 0 .../main/java/org/apache/spark/unsafe/types/ByteArray.java | 0 .../java/org/apache/spark/unsafe/types/CalendarInterval.java | 0 .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 0 .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 0 .../java/org/apache/spark/unsafe/array/LongArraySuite.java | 0 .../org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java | 0 .../org/apache/spark/unsafe/types/CalendarIntervalSuite.java | 0 .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 0 .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 0 pom.xml | 4 ++-- 27 files changed, 4 insertions(+), 4 deletions(-) rename {tags => common/tags}/README.md (100%) rename {tags => common/tags}/pom.xml (97%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/DockerTest.java (100%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java (100%) rename {tags => common/tags}/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java (100%) rename {unsafe => common/unsafe}/pom.xml (98%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/KVIterator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/Platform.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/array/LongArray.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/ByteArray.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java (100%) rename {unsafe => common/unsafe}/src/main/java/org/apache/spark/unsafe/types/UTF8String.java (100%) rename {unsafe => 
common/unsafe}/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java (100%) rename {unsafe => common/unsafe}/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java (100%) rename {unsafe => common/unsafe}/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala (100%) diff --git a/tags/README.md b/common/tags/README.md similarity index 100% rename from tags/README.md rename to common/tags/README.md diff --git a/tags/pom.xml b/common/tags/pom.xml similarity index 97% rename from tags/pom.xml rename to common/tags/pom.xml index 3e8e6f618287..8e702b4fefe8 100644 --- a/tags/pom.xml +++ b/common/tags/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/unsafe/pom.xml b/common/unsafe/pom.xml similarity index 98% rename from unsafe/pom.xml rename to common/unsafe/pom.xml index 75fea556eeae..5250014739da 100644 --- a/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/Platform.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java similarity index 100% rename from 
unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala similarity index 100% rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala diff --git a/pom.xml b/pom.xml index 2376e307ced1..2148379896d3 100644 --- a/pom.xml +++ b/pom.xml @@ -89,7 +89,8 @@ common/sketch common/network-common common/network-shuffle - tags + common/unsafe + common/tags core graphx mllib @@ -99,7 +100,6 @@ sql/core sql/hive docker-integration-tests - unsafe assembly external/twitter external/flume From a640c5b4fbd653919a5897a7b11f16328f2094eb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Mar 2016 17:27:57 -0800 Subject: [PATCH 22/28] [SPARK-13598] [SQL] remove LeftSemiJoinBNL ## What changes were proposed in this pull request? Broadcast left semi join without join keys is already supported by BroadcastNestedLoopJoin; it has the same implementation as LeftSemiJoinBNL, so we should remove the latter. ## How was this patch tested? Updated unit tests. Author: Davies Liu Closes #11448 from davies/remove_bnl.
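To see the change in effect: a left semi join with no equi-join keys, which previously planned as LeftSemiJoinBNL, now plans as BroadcastNestedLoopJoin. A minimal sketch of how to observe this (assumptions: a `sqlContext` with the `testData` and `testData2` temporary tables registered, as in the JoinSuite changes below):

```
// Plan a keyless LEFT SEMI JOIN and inspect the chosen physical operator.
// After this patch it should be BroadcastNestedLoopJoin, not LeftSemiJoinBNL.
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin

val plan = sqlContext.sql("SELECT * FROM testData LEFT SEMI JOIN testData2")
  .queryExecution.sparkPlan
assert(plan.collect { case j: BroadcastNestedLoopJoin => j }.nonEmpty)
```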
--- .../spark/sql/execution/SparkStrategies.scala | 3 - .../sql/execution/joins/LeftSemiJoinBNL.scala | 80 ------------------- .../org/apache/spark/sql/JoinSuite.scala | 5 +- .../sql/execution/joins/SemiJoinSuite.scala | 9 --- 4 files changed, 2 insertions(+), 95 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd8c96d5fa1d..0255103b63d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -71,9 +71,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil - // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => - joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala deleted file mode 100644 index df6dac88187c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. 
- */ -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = left.output - - /** The Streamed Relation */ - override def left: SparkPlan = streamed - - /** The Broadcast relation */ - override def right: SparkPlan = broadcast - - override def requiredChildDistribution: Seq[Distribution] = { - UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil - } - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - val relation = broadcastedRelation.value - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < relation.length && !matched) { - if (boundCondition(joinedRow(streamedRow, relation(i)))) { - matched = true - } - i += 1 - } - if (matched) { - numOutputRows += 1 - } - matched - }) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3dab848e7b03..5b98c11ef2a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -47,7 +47,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j @@ -67,7 +66,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), @@ -465,7 +464,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", - classOf[LeftSemiJoinBNL]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 355f916a9755..bc341db5571b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -95,15 +95,6 @@ class SemiJoinSuite 
extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using LeftSemiJoinBNL") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - test(s"$testName using BroadcastNestedLoopJoin build left") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => From e42724b12b976b3276accc1132f446fa67f7f981 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 1 Mar 2016 17:34:21 -0800 Subject: [PATCH 23/28] [SPARK-13167][SQL] Include rows with null values for partition column when reading from JDBC datasources. Rows with null values in the partition column are not included in the results because none of the partition where clauses specifies an 'is null' predicate on the partition column. This fix adds an 'is null' predicate on the partition column to the first JDBC partition's where clause. Example: JDBCPartition(THEID < 1 or THEID is null, 0), JDBCPartition(THEID >= 1 AND THEID < 2, 1), JDBCPartition(THEID >= 2, 2) Author: sureshthalamati Closes #11063 from sureshthalamati/nullable_jdbc_part_col_spark-13167. --- .../datasources/jdbc/JDBCRelation.scala | 8 +++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index ee6373d03e1f..9e336422d1f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -44,6 +44,12 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. + * + * A null value predicate is added to the first partition where clause to include + * rows with a null value in the partition column.
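+ * For example (illustrative clauses, taken from the SPARK-13167 description),
+ * the generated partition where clauses may look like:
+ * JDBCPartition(THEID < 1 or THEID is null, 0),
+ * JDBCPartition(THEID >= 1 AND THEID < 2, 1),
+ * JDBCPartition(THEID >= 2, 2)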
+ * + * @param partitioning partition information to generate the where clause for each partition + * @return an array of partitions with a where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -66,7 +72,7 @@ private[sql] object JDBCRelation { if (upperBound == null) { lowerBound } else if (lowerBound == null) { - upperBound + s"$upperBound or $column is null" } else { s"$lowerBound AND $upperBound" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index f8a9a95c873a..30a5e2ea4acd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -171,6 +171,27 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( + "create table test.emp(name TEXT(32) NOT NULL," + + " theid INTEGER, \"Dept\" INTEGER)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('fred', 1, 10)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('mary', 2, null)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('kathy', null, null)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE nullparts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass', + |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4') + """.stripMargin.replaceAll("\n", " ")) + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -338,6 +359,23 @@ class JDBCSuite extends SparkFunSuite .collect().length === 3) } + test("Partitioning on column that might have null values.") { + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + .collect().length === 4) + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + .collect().length === 4) + // partitioning on a nullable quoted column + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + .collect().length === 4) + } + + test("SELECT * on partitioned table with a nullable partition column") { + assert(sql("SELECT * FROM nullparts").collect().size == 4) + } + test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) From 9495c40f227b785d852abdc307461d2e7e5c2011 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 1 Mar 2016 21:26:47 -0800 Subject: [PATCH 24/28] [SPARK-13008][ML][PYTHON] Put one alg per line in pyspark.ml all lists This is to fix a long-time annoyance: Whenever we add a new algorithm to pyspark.ml, we have to add it to the ```__all__``` list at the top. Since we keep it alphabetized, it often creates a lot more changes than needed. It is also easy to add the Estimator and forget the Model. I'm going to switch it to have one algorithm per line. This also alphabetizes a few out-of-place classes in pyspark.ml.feature. No changes have been made to the moved classes. CC: thunterdb Author: Joseph K.
Bradley Closes #10927 from jkbradley/ml-python-all-list. --- python/pyspark/ml/classification.py | 11 +++++---- python/pyspark/ml/clustering.py | 3 ++- python/pyspark/ml/feature.py | 37 ++++++++++++++++++++++------- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3179fb30ab4d..253af15cb5cd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,11 +26,12 @@ from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', - 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel', 'MultilayerPerceptronClassifier', - 'MultilayerPerceptronClassificationModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', + 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel', + 'NaiveBayes', 'NaiveBayesModel', + 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel'] @inherit_doc diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 611b9190491c..1cea477acb47 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -21,7 +21,8 @@ from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['KMeans', 'KMeansModel', 'BisectingKMeans', 'BisectingKMeansModel'] +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', + 'KMeans', 'KMeansModel'] class KMeansModel(JavaModel, MLWritable, MLReadable): diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 369f3508fda5..fb31c7310c0a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,15 +27,34 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', - 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', - 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinMaxScaler', 'MinMaxScalerModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', - 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', - 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', 'ChiSqSelector', - 'ChiSqSelectorModel'] +__all__ = ['Binarizer', + 'Bucketizer', + 'ChiSqSelector', 'ChiSqSelectorModel', + 'CountVectorizer', 'CountVectorizerModel', + 'DCT', + 'ElementwiseProduct', + 'HashingTF', + 'IDF', 'IDFModel', + 'IndexToString', + 'MaxAbsScaler', 'MaxAbsScalerModel', + 'MinMaxScaler', 'MinMaxScalerModel', + 'NGram', + 'Normalizer', + 'OneHotEncoder', + 'PCA', 'PCAModel', + 'PolynomialExpansion', + 'QuantileDiscretizer', + 'RegexTokenizer', + 'RFormula', 'RFormulaModel', + 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', + 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', + 'VectorAssembler', + 'VectorIndexer', 'VectorIndexerModel', + 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc From b4d096ded6540c46a2b07b0b0897cbb0b43ba1e0 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 1 Mar 2016 21:28:30 
-0800 Subject: [PATCH 25/28] [BUILD][MINOR] Fix SBT build error with network-yarn module ## What changes were proposed in this pull request? ``` [error] Expected ID character [error] Not a valid command: common (similar: completions) [error] Expected project ID [error] Expected configuration [error] Expected ':' (if selecting a configuration) [error] Expected key [error] Not a valid key: common (similar: commands) [error] common/network-yarn/test ``` `common/network-yarn` is not a valid sbt project; we should change it to `network-yarn`. ## How was this patch tested? Locally ran the unit tests. CC rxin: we should either change it here or change the sbt project name. Author: jerryshao Closes #11456 from jerryshao/build-fix. --- dev/sparktestsupport/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4e04672ad39e..e4f2edaf9511 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -477,7 +477,7 @@ def __hash__(self): ], sbt_test_goals=[ "yarn/test", - "common/network-yarn/test", + "network-yarn/test", ], test_tags=[ "org.apache.spark.tags.ExtendedYarnTest" From 366f26d2da0437aab99fd88b70ca12ae18958451 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 2 Mar 2016 11:48:23 +0000 Subject: [PATCH 26/28] [MINOR][STREAMING] Replace deprecated `apply` with `create` in example. ## What changes were proposed in this pull request? Twitter Algebird deprecated `apply` in HyperLogLog.scala. ``` deprecated("Use toHLL", since = "0.10.0 / 2015-05") def apply[T <% Array[Byte]](t: T) = create(t) ``` This PR replaces the deprecated `apply` usage with the new `create`, in line with the upstream change. ## How was this patch tested? Manual. ``` /bin/spark-submit --class org.apache.spark.examples.streaming.TwitterAlgebirdHLL examples/target/scala-2.11/spark-examples-2.0.0-SNAPSHOT-hadoop2.2.0.jar ``` Author: Dongjoon Hyun Closes #11451 from dongjoon-hyun/replace_deprecated_hll_apply. --- .../apache/spark/examples/streaming/TwitterAlgebirdHLL.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 0ec6214fdef1..6442b2a4e294 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -62,7 +62,7 @@ object TwitterAlgebirdHLL { var userSet: Set[Long] = Set() val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll(id)) + ids.map(id => hll.create(id)) }).reduce(_ + _) val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) From 75e618def1116f7413be11d147fcd02829caba08 Mon Sep 17 00:00:00 2001 From: Wojciech Jurczyk Date: Wed, 2 Mar 2016 15:32:32 +0000 Subject: [PATCH 27/28] Fix run-tests.py typos ## What changes were proposed in this pull request? The PR fixes typos in an error message in dev/run-tests.py. Author: Wojciech Jurczyk Closes #11467 from wjur/wjur/typos_run_tests.
--- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 6febbf108900..b65d1a309cb4 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -488,7 +488,7 @@ def main(): if which("R"): run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) else: - print("Can't install SparkR as R is was not found in PATH") + print("Cannot install SparkR as R was not found in PATH") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables From d8afd45f8949e0914ce4bd56d832b1158e3c9220 Mon Sep 17 00:00:00 2001 From: lgieron Date: Wed, 2 Mar 2016 15:57:27 +0000 Subject: [PATCH 28/28] [SPARK-13515] Make FormatNumber work irrespective of locale. ## What changes were proposed in this pull request? Change in class FormatNumber to make it work irrespective of locale. ## How was this patch tested? Unit tests. Author: lgieron Closes #11396 from lgieron/SPARK-13515_Fix_Format_Number. --- .../expressions/stringExpressions.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 4be065b30a21..3ee19cc4ad71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.DecimalFormat +import java.text.{DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow @@ -938,8 +938,10 @@ case class FormatNumber(x: Expression, d: Expression) @transient private val pattern: StringBuffer = new StringBuffer() + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. @transient - private val numberFormat: DecimalFormat = new DecimalFormat("") + private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -962,10 +964,9 @@ case class FormatNumber(x: Expression, d: Expression) pattern.append("0") } } - val dFormat = new DecimalFormat(pattern.toString) lastDValue = dValue - numberFormat.applyPattern(dFormat.toPattern) + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { @@ -992,6 +993,11 @@ case class FormatNumber(x: Expression, d: Expression) val sb = classOf[StringBuffer].getName val df = classOf[DecimalFormat].getName + val dfs = classOf[DecimalFormatSymbols].getName + val l = classOf[Locale].getName + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. 
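+ // Illustrative: new DecimalFormat("0.00", new DecimalFormatSymbols(Locale.FRANCE))
+ //   .format(1.5) yields "1,50", while with DecimalFormatSymbols(Locale.US) the
+ //   same call yields "1.50", regardless of the JVM's default locale.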
+ val usLocale = "US" val lastDValue = ctx.freshName("lastDValue") val pattern = ctx.freshName("pattern") val numberFormat = ctx.freshName("numberFormat") @@ -999,7 +1005,8 @@ case class FormatNumber(x: Expression, d: Expression) val dFormat = ctx.freshName("dFormat") ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") + ctx.addMutableState(df, numberFormat, + s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { @@ -1013,9 +1020,8 @@ case class FormatNumber(x: Expression, d: Expression) $pattern.append("0"); } } - $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; - $numberFormat.applyPattern($dFormat.toPattern()); + $numberFormat.applyLocalizedPattern($pattern.toString()); } ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else {