From 1c3d2216380c9cc89ea829588305b5f31c71d6d5 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 29 Apr 2016 10:42:52 -0700 Subject: [PATCH 1/3] Rebase with master. --- .../org/apache/spark/ml/clustering/LDA.scala | 7 +- python/pyspark/ml/clustering.py | 488 +++++++++++++++++- python/pyspark/ml/tests.py | 57 +- 3 files changed, 546 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 1554d568af615..38ecc5a102c12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -355,7 +355,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * :: Experimental :: * Model fitted by [[LDA]]. * - * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary) + * @param vocabSize Vocabulary size (number of terms or words in the vocabulary) * @param sparkSession Used to construct local DataFrames for returning query results */ @Since("1.6.0") @@ -745,9 +745,8 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { * - "topic": multinomial distribution over terms representing some concept * - "document": one piece of text, corresponding to one row in the input data * - * References: - * - Original LDA paper (journal version): - * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. * * Input data (featuresCol): * LDA is given a collection of documents as input data, via the featuresCol parameter. diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 16ce02ee7d71a..50ebf4fde1cf5 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -23,7 +23,8 @@ __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'KMeans', 'KMeansModel', - 'GaussianMixture', 'GaussianMixtureModel'] + 'GaussianMixture', 'GaussianMixtureModel', + 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): @@ -450,6 +451,491 @@ def _create_model(self, java_model): return BisectingKMeansModel(java_model) +@inherit_doc +class LDAModel(JavaModel): + """ + .. note:: Experimental + + Latent Dirichlet Allocation (LDA) model. + This abstraction permits for different underlying representations, + including local and distributed data structures. + + .. versionadded:: 2.0.0 + """ + + @since("2.0.0") + def isDistributed(self): + """ + Indicates whether this instance is of type DistributedLDAModel + """ + return self._call_java("isDistributed") + + @since("2.0.0") + def vocabSize(self): + """Vocabulary size (number of terms or words in the vocabulary)""" + return self._call_java("vocabSize") + + @since("2.0.0") + def topicsMatrix(self): + """ + Inferred topics, where each topic is represented by a distribution over terms. + This is a matrix of size vocabSize x k, where each column is a topic. + No guarantees are given about the ordering of the topics. + + WARNING: If this model is actually a :py:class:`DistributedLDAModel` instance produced by + the Expectation-Maximization ("em") `optimizer`, then this method could involve + collecting a large amount of data to the driver (on the order of vocabSize x k). 
+        """
+        return self._call_java("topicsMatrix")
+
+    @since("2.0.0")
+    def logLikelihood(self, dataset):
+        """
+        Calculates a lower bound on the log likelihood of the entire corpus.
+        See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+
+        WARNING: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
+        :py:attr:`optimizer` is set to "em"), this involves collecting a large
+        :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
+        """
+        return self._call_java("logLikelihood", dataset)
+
+    @since("2.0.0")
+    def logPerplexity(self, dataset):
+        """
+        Calculates an upper bound on perplexity. (Lower is better.)
+        See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
+
+        WARNING: If this model is an instance of :py:class:`DistributedLDAModel` (produced when
+        :py:attr:`optimizer` is set to "em"), this involves collecting a large
+        :py:func:`topicsMatrix` to the driver. This implementation may be changed in the future.
+        """
+        return self._call_java("logPerplexity", dataset)
+
+    @since("2.0.0")
+    def describeTopics(self, maxTermsPerTopic=10):
+        """
+        Return the topics described by their top-weighted terms.
+        """
+        return self._call_java("describeTopics", maxTermsPerTopic)
+
+    @since("2.0.0")
+    def estimatedDocConcentration(self):
+        """
+        Value for :py:attr:`LDA.docConcentration` estimated from data.
+        If Online LDA was used and :py:attr:`LDA.optimizeDocConcentration` was set to false,
+        then this returns the fixed (given) value for the :py:attr:`LDA.docConcentration` parameter.
+        """
+        return self._call_java("estimatedDocConcentration")
+
+
+@inherit_doc
+class DistributedLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Distributed model fitted by :py:class:`LDA`.
+    This type of model is currently only produced by Expectation-Maximization (EM).
+
+    This model stores the inferred topics, the full training dataset, and the topic distribution
+    for each training document.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @since("2.0.0")
+    def toLocal(self):
+        """
+        Convert this distributed model to a local representation. This discards info about the
+        training dataset.
+
+        WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver.
+        """
+        return LocalLDAModel(self._call_java("toLocal"))
+
+    @since("2.0.0")
+    def trainingLogLikelihood(self):
+        """
+        Log likelihood of the observed tokens in the training set,
+        given the current parameter estimates:
+        log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
+
+        Notes:
+          - This excludes the prior; for that, use :py:func:`logPrior`.
+          - Even with :py:func:`logPrior`, this is NOT the same as the data log likelihood given
+            the hyperparameters.
+          - This is computed from the topic distributions computed during training. If you call
+            :py:func:`logLikelihood` on the same training dataset, the topic distributions
+            will be computed again, possibly giving different results.
+        """
+        return self._call_java("trainingLogLikelihood")
+
+    @since("2.0.0")
+    def logPrior(self):
+        """
+        Log probability of the current parameter estimate:
+        log P(topics, topic distributions for docs | alpha, eta)
+        """
+        return self._call_java("logPrior")
+
+    @since("2.0.0")
+    def getCheckpointFiles(self):
+        """
+        If using checkpointing and :py:attr:`LDA.keepLastCheckpoint` is set to true, then there may
+        be saved checkpoint files. This method is provided so that users can manage those files.
+
+        Note that removing the checkpoints can cause failures if a partition is lost and is needed
+        by certain :py:class:`DistributedLDAModel` methods. Reference counting will clean up the
+        checkpoints when this model and derivative data go out of scope.
+
+        :return: List of checkpoint files from training
+        """
+        return self._call_java("getCheckpointFiles")
+
+
+@inherit_doc
+class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Local (non-distributed) model fitted by :py:class:`LDA`.
+    This model stores the inferred topics only; it does not store info about the training dataset.
+
+    .. versionadded:: 2.0.0
+    """
+    pass
+
+
+@inherit_doc
+class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval,
+          JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+
+    Terminology:
+
+     - "term" = "word": an element of the vocabulary
+     - "token": instance of a term appearing in a document
+     - "topic": multinomial distribution over terms representing some concept
+     - "document": one piece of text, corresponding to one row in the input data
+
+    Original LDA paper (journal version):
+      Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+
+    Input data (featuresCol):
+     LDA is given a collection of documents as input data, via the featuresCol parameter.
+     Each document is specified as a :py:class:`Vector` of length vocabSize, where each entry is the
+     count for the corresponding term (word) in the document. Feature transformers such as
+     :py:class:`pyspark.ml.feature.Tokenizer` and :py:class:`pyspark.ml.feature.CountVectorizer`
+     can be useful for converting text to word count vectors.
+
+    >>> from pyspark.mllib.linalg import Vectors, SparseVector
+    >>> from pyspark.ml.clustering import LDA
+    >>> df = sqlContext.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
+    ...                                  [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
+    >>> lda = LDA(k=2, seed=1, optimizer="em")
+    >>> model = lda.fit(df)
+    >>> model.isDistributed()
+    True
+    >>> localModel = model.toLocal()
+    >>> localModel.isDistributed()
+    False
+    >>> model.vocabSize()
+    2
+    >>> model.describeTopics().show()
+    +-----+-----------+--------------------+
+    |topic|termIndices|         termWeights|
+    +-----+-----------+--------------------+
+    |    0|     [1, 0]|[0.50401530077160...|
+    |    1|     [0, 1]|[0.50401530077160...|
+    +-----+-----------+--------------------+
+    ...
+    >>> model.topicsMatrix()
+    DenseMatrix(2, 2, [0.496, 0.504, 0.504, 0.496], 0)
+    >>> lda_path = temp_path + "/lda"
+    >>> lda.save(lda_path)
+    >>> sameLDA = LDA.load(lda_path)
+    >>> distributed_model_path = temp_path + "/lda_distributed_model"
+    >>> model.save(distributed_model_path)
+    >>> sameModel = DistributedLDAModel.load(distributed_model_path)
+    >>> local_model_path = temp_path + "/lda_local_model"
+    >>> localModel.save(local_model_path)
+    >>> sameLocalModel = LocalLDAModel.load(local_model_path)
+
+    .. versionadded:: 2.0.0
+    """
+
+    k = Param(Params._dummy(), "k", "number of topics (clusters) to infer",
+              typeConverter=TypeConverters.toInt)
+    optimizer = Param(Params._dummy(), "optimizer",
+                      "Optimizer or inference algorithm used to estimate the LDA model. "
+                      "Supported: online, em", typeConverter=TypeConverters.toString)
+    learningOffset = Param(Params._dummy(), "learningOffset",
+                           "A (positive) learning parameter that downweights early iterations."
+                           " Larger values make early iterations count less",
+                           typeConverter=TypeConverters.toFloat)
+    learningDecay = Param(Params._dummy(), "learningDecay", "Learning rate, set as an "
+                          "exponential decay rate. This should be between (0.5, 1.0] to "
+                          "guarantee asymptotic convergence.", typeConverter=TypeConverters.toFloat)
+    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
+                            "Fraction of the corpus to be sampled and used in each iteration "
+                            "of mini-batch gradient descent, in range (0, 1].",
+                            typeConverter=TypeConverters.toFloat)
+    optimizeDocConcentration = Param(Params._dummy(), "optimizeDocConcentration",
+                                     "Indicates whether the docConcentration (Dirichlet parameter "
+                                     "for document-topic distribution) will be optimized during "
+                                     "training.", typeConverter=TypeConverters.toBoolean)
+    docConcentration = Param(Params._dummy(), "docConcentration",
+                             "Concentration parameter (commonly named \"alpha\") for the "
+                             "prior placed on documents' distributions over topics (\"theta\").",
+                             typeConverter=TypeConverters.toListFloat)
+    topicConcentration = Param(Params._dummy(), "topicConcentration",
+                               "Concentration parameter (commonly named \"beta\" or \"eta\") for "
+                               "the prior placed on topics' distributions over terms.",
+                               typeConverter=TypeConverters.toFloat)
+    topicDistributionCol = Param(Params._dummy(), "topicDistributionCol",
+                                 "Output column with estimates of the topic mixture distribution "
+                                 "for each document (often called \"theta\" in the literature). "
+                                 "Returns a vector of zeros for an empty document.",
+                                 typeConverter=TypeConverters.toString)
+    keepLastCheckpoint = Param(Params._dummy(), "keepLastCheckpoint",
+                               "(For EM optimizer) If using checkpointing, this indicates whether"
+                               " to keep the last checkpoint. If false, then the checkpoint will be"
+                               " deleted. Deleting the checkpoint can cause failures if a data"
+                               " partition is lost, so set this bit with care.",
+                               typeConverter=TypeConverters.toBoolean)
+
+    @keyword_only
+    def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
+                 k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                 subsamplingRate=0.05, optimizeDocConcentration=True,
+                 docConcentration=None, topicConcentration=None,
+                 topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+        """
+        __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
+                 k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
+                 subsamplingRate=0.05, optimizeDocConcentration=True,\
+                 docConcentration=None, topicConcentration=None,\
+                 topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+        """
+        super(LDA, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid)
+        self._setDefault(maxIter=20, checkpointInterval=10,
+                         k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                         subsamplingRate=0.05, optimizeDocConcentration=True,
+                         topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    def _create_model(self, java_model):
+        if self.getOptimizer() == "em":
+            return DistributedLDAModel(java_model)
+        else:
+            return LocalLDAModel(java_model)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
+                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
+                  subsamplingRate=0.05, optimizeDocConcentration=True,
+                  docConcentration=None, topicConcentration=None,
+                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+        """
+        setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
+                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
+                  subsamplingRate=0.05, optimizeDocConcentration=True,\
+                  docConcentration=None, topicConcentration=None,\
+                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+
+        Sets params for LDA.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.0.0")
+    def setK(self, value):
+        """
+        Sets the value of :py:attr:`k`.
+
+        >>> algo = LDA().setK(10)
+        >>> algo.getK()
+        10
+        """
+        return self._set(k=value)
+
+    @since("2.0.0")
+    def getK(self):
+        """
+        Gets the value of :py:attr:`k` or its default value.
+        """
+        return self.getOrDefault(self.k)
+
+    @since("2.0.0")
+    def setOptimizer(self, value):
+        """
+        Sets the value of :py:attr:`optimizer`.
+        Currently only 'em' and 'online' are supported.
+
+        >>> algo = LDA().setOptimizer("em")
+        >>> algo.getOptimizer()
+        'em'
+        """
+        return self._set(optimizer=value)
+
+    @since("2.0.0")
+    def getOptimizer(self):
+        """
+        Gets the value of :py:attr:`optimizer` or its default value.
+        """
+        return self.getOrDefault(self.optimizer)
+
+    @since("2.0.0")
+    def setLearningOffset(self, value):
+        """
+        Sets the value of :py:attr:`learningOffset`.
+
+        >>> algo = LDA().setLearningOffset(100)
+        >>> algo.getLearningOffset()
+        100.0
+        """
+        return self._set(learningOffset=value)
+
+    @since("2.0.0")
+    def getLearningOffset(self):
+        """
+        Gets the value of :py:attr:`learningOffset` or its default value.
+        """
+        return self.getOrDefault(self.learningOffset)
+
+    @since("2.0.0")
+    def setLearningDecay(self, value):
+        """
+        Sets the value of :py:attr:`learningDecay`.
+ + >>> algo = LDA().setLearningDecay(0.1) + >>> algo.getLearningDecay() + 0.1... + """ + return self._set(learningDecay=value) + + @since("2.0.0") + def getLearningDecay(self): + """ + Gets the value of :py:attr:`learningDecay` or its default value. + """ + return self.getOrDefault(self.learningDecay) + + @since("2.0.0") + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + + >>> algo = LDA().setSubsamplingRate(0.1) + >>> algo.getSubsamplingRate() + 0.1... + """ + return self._set(subsamplingRate=value) + + @since("2.0.0") + def getSubsamplingRate(self): + """ + Gets the value of :py:attr:`subsamplingRate` or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + @since("2.0.0") + def setOptimizeDocConcentration(self, value): + """ + Sets the value of :py:attr:`optimizeDocConcentration`. + + >>> algo = LDA().setOptimizeDocConcentration(True) + >>> algo.getOptimizeDocConcentration() + True + """ + return self._set(optimizeDocConcentration=value) + + @since("2.0.0") + def getOptimizeDocConcentration(self): + """ + Gets the value of :py:attr:`optimizeDocConcentration` or its default value. + """ + return self.getOrDefault(self.optimizeDocConcentration) + + @since("2.0.0") + def setDocConcentration(self, value): + """ + Sets the value of :py:attr:`docConcentration`. + + >>> algo = LDA().setDocConcentration([0.1, 0.2]) + >>> algo.getDocConcentration() + [0.1..., 0.2...] + """ + return self._set(docConcentration=value) + + @since("2.0.0") + def getDocConcentration(self): + """ + Gets the value of :py:attr:`docConcentration` or its default value. + """ + return self.getOrDefault(self.docConcentration) + + @since("2.0.0") + def setTopicConcentration(self, value): + """ + Sets the value of :py:attr:`topicConcentration`. + + >>> algo = LDA().setTopicConcentration(0.5) + >>> algo.getTopicConcentration() + 0.5... + """ + return self._set(topicConcentration=value) + + @since("2.0.0") + def getTopicConcentration(self): + """ + Gets the value of :py:attr:`topicConcentration` or its default value. + """ + return self.getOrDefault(self.topicConcentration) + + @since("2.0.0") + def setTopicDistributionCol(self, value): + """ + Sets the value of :py:attr:`topicDistributionCol`. + + >>> algo = LDA().setTopicDistributionCol("topicDistributionCol") + >>> algo.getTopicDistributionCol() + 'topicDistributionCol' + """ + return self._set(topicDistributionCol=value) + + @since("2.0.0") + def getTopicDistributionCol(self): + """ + Gets the value of :py:attr:`topicDistributionCol` or its default value. + """ + return self.getOrDefault(self.topicDistributionCol) + + @since("2.0.0") + def setKeepLastCheckpoint(self, value): + """ + Sets the value of :py:attr:`keepLastCheckpoint`. + + >>> algo = LDA().setKeepLastCheckpoint(False) + >>> algo.getKeepLastCheckpoint() + False + """ + return self._set(keepLastCheckpoint=value) + + @since("2.0.0") + def getKeepLastCheckpoint(self): + """ + Gets the value of :py:attr:`keepLastCheckpoint` or its default value. 
+ """ + return self.getOrDefault(self.keepLastCheckpoint) + + if __name__ == "__main__": import doctest import pyspark.ml.clustering diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 36cecd4682fea..e7d4c0af45983 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -46,7 +46,7 @@ from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import ( LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel) -from pyspark.ml.clustering import KMeans +from pyspark.ml.clustering import * from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters @@ -809,6 +809,61 @@ def test_decisiontree_regressor(self): pass +class LDATest(PySparkTestCase): + + def _compare(self, m1, m2): + """ + Temp method for comparing instances. + TODO: Replace with generic implementation once SPARK-14706 is merged. + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + if m1.isDefined(p): + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + if isinstance(m1, LDAModel): + self.assertEqual(m1.vocabSize(), m2.vocabSize()) + self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) + + def test_persistence(self): + # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([ + [1, Vectors.dense([0.0, 1.0])], + [2, Vectors.sparse(2, {0: 1.0})], + ], ["id", "features"]) + # Fit model + lda = LDA(k=2, seed=1, optimizer="em") + distributedModel = lda.fit(df) + self.assertTrue(distributedModel.isDistributed()) + localModel = distributedModel.toLocal() + self.assertFalse(localModel.isDistributed()) + # Define paths + path = tempfile.mkdtemp() + lda_path = path + "/lda" + dist_model_path = path + "/distLDAModel" + local_model_path = path + "/localLDAModel" + # Test LDA + lda.save(lda_path) + lda2 = LDA.load(lda_path) + self._compare(lda, lda2) + # Test DistributedLDAModel + distributedModel.save(dist_model_path) + distributedModel2 = DistributedLDAModel.load(dist_model_path) + self._compare(distributedModel, distributedModel2) + # Test LocalLDAModel + localModel.save(local_model_path) + localModel2 = LocalLDAModel.load(local_model_path) + self._compare(localModel, localModel2) + # Clean up + try: + rmtree(path) + except OSError: + pass + + class TrainingSummaryTest(PySparkTestCase): def test_linear_regression_summary(self): From ebe2900aadd3af0114ed71506088c6a736dd5002 Mon Sep 17 00:00:00 2001 From: Arun Allamsetty Date: Thu, 21 Dec 2017 15:52:15 -0700 Subject: [PATCH 2/3] SPARK-17916: Fix empty string being parsed as null when nullValue is set. 
--- .../datasources/csv/CSVOptions.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index a13a5a34b4a84..2df8370901d7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -152,7 +152,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4fe45420b4e77..1d9ccd6b76aa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1248,4 +1248,49 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val sparkSession = spark + + val elems = Seq(("bar"), (""), (null: String)) + + // Checks for new behavior where an empty string is not coerced to null. + withTempDir { dir => + val outDir = new File(dir, "out").getCanonicalPath + val nullValue = "\\N" + + import sparkSession.implicits._ + val dsIn = spark.createDataset(elems) + dsIn.write + .option("nullValue", nullValue) + .csv(outDir) + val dsOut = spark.read + .option("nullValue", nullValue) + .schema(dsIn.schema) + .csv(outDir) + .as[(String)] + val computed = dsOut.collect.toSeq + val expected = Seq(("bar"), (null: String)) + + assert(computed.size === 2) + assert(computed.sameElements(expected)) + } + // Keeps the old behavior for when nullValue is not passed. + withTempDir { dir => + val outDir = new File(dir, "out").getCanonicalPath + + import sparkSession.implicits._ + val dsIn = spark.createDataset(elems) + dsIn.write.csv(outDir) + val dsOut = spark.read + .schema(dsIn.schema) + .csv(outDir) + .as[(String)] + val computed = dsOut.collect.toSeq + val expected = Seq(("bar")) + + assert(computed.size === 1) + assert(computed.sameElements(expected)) + } + } } From 156d755d5a734a00c4c69dfc3565364f3843fca1 Mon Sep 17 00:00:00 2001 From: Arun Allamsetty Date: Sun, 24 Dec 2017 23:41:40 -0700 Subject: [PATCH 3/3] Incorporate code review suggestions. 
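The review feedback is addressed by documenting in CSVOptions why the writer's emptyValue is now
always "", and by rewriting the test around a two-column DataFrame with withTempPath and
checkAnswer.

A minimal sketch of the default path, which stays unchanged (hypothetical names; assumes a
DataFrame `df` whose `name` column holds "John Doe", "", "-", and null, written to a directory
`out`):

    df.write.csv(out)                                // no explicit nullValue
    val back = spark.read.schema(df.schema).csv(out)
    // "" still reads back as null, while the literal "-" is kept, because nullValue
    // defaults to the empty string and takes precedence over the writer's emptyValue.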
---
 .../datasources/csv/CSVOptions.scala          |  4 +
 .../execution/datasources/csv/CSVSuite.scala  | 81 ++++++++++---------
 2 files changed, 45 insertions(+), 40 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 2df8370901d7b..29e76875ad9a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -152,6 +152,10 @@ class CSVOptions(
     writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
     writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
     writerSettings.setNullValue(nullValue)
+    // The Univocity parser parses empty strings as `null` by default. This is the default behavior
+    // for Spark too, since `nullValue` defaults to an empty string and takes precedence over
+    // setEmptyValue(). But when `nullValue` is set to a different value, an empty string should be
+    // parsed as an empty string, not as `null`.
     writerSettings.setEmptyValue("")
     writerSettings.setSkipEmptyLines(true)
     writerSettings.setQuoteAllFields(quoteAll)
     writerSettings.setQuoteEscapingEnabled(escapeQuotes)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 1d9ccd6b76aa4..140ac9b52e640 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1250,47 +1250,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") {
-    val sparkSession = spark
-
-    val elems = Seq(("bar"), (""), (null: String))
-
-    // Checks for new behavior where an empty string is not coerced to null.
-    withTempDir { dir =>
-      val outDir = new File(dir, "out").getCanonicalPath
-      val nullValue = "\\N"
-
-      import sparkSession.implicits._
-      val dsIn = spark.createDataset(elems)
-      dsIn.write
-        .option("nullValue", nullValue)
-        .csv(outDir)
-      val dsOut = spark.read
-        .option("nullValue", nullValue)
-        .schema(dsIn.schema)
-        .csv(outDir)
-        .as[(String)]
-      val computed = dsOut.collect.toSeq
-      val expected = Seq(("bar"), (null: String))
-
-      assert(computed.size === 2)
-      assert(computed.sameElements(expected))
+    val litNull: String = null
+    val df = Seq(
+      (1, "John Doe"),
+      (2, ""),
+      (3, "-"),
+      (4, litNull)
+    ).toDF("id", "name")
+
+    // Checks for new behavior where an empty string is not coerced to null when `nullValue` is
+    // set to anything but an empty string literal.
+    withTempPath { path =>
+      df.write
+        .option("nullValue", "-")
+        .csv(path.getAbsolutePath)
+      val computed = spark.read
+        .option("nullValue", "-")
+        .schema(df.schema)
+        .csv(path.getAbsolutePath)
+      val expected = Seq(
+        (1, "John Doe"),
+        (2, ""),
+        (3, litNull),
+        (4, litNull)
+      ).toDF("id", "name")
+
+      checkAnswer(computed, expected)
     }
-    // Keeps the old behavior for when nullValue is not passed.
-    withTempDir { dir =>
-      val outDir = new File(dir, "out").getCanonicalPath
-
-      import sparkSession.implicits._
-      val dsIn = spark.createDataset(elems)
-      dsIn.write.csv(outDir)
-      val dsOut = spark.read
-        .schema(dsIn.schema)
-        .csv(outDir)
-        .as[(String)]
-      val computed = dsOut.collect.toSeq
-      val expected = Seq(("bar"))
-
-      assert(computed.size === 1)
-      assert(computed.sameElements(expected))
+    // Keeps the old behavior: an empty string is still coerced to null when nullValue is not set.
+    withTempPath { path =>
+      df.write
+        .csv(path.getAbsolutePath)
+      val computed = spark.read
+        .schema(df.schema)
+        .csv(path.getAbsolutePath)
+      val expected = Seq(
+        (1, "John Doe"),
+        (2, litNull),
+        (3, "-"),
+        (4, litNull)
+      ).toDF("id", "name")
+
+      checkAnswer(computed, expected)
     }
   }
 }