diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala index 6900b4153a7eb..8be23894a732f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala @@ -41,6 +41,9 @@ private[libsvm] class LibSVMOptions(@transient private val parameters: CaseInsen case o => throw new IllegalArgumentException(s"Invalid value `$o` for parameter " + s"`$VECTOR_TYPE`. Expected types are `sparse` and `dense`.") } + + val lineSeparator: String = parameters.getOrElse(LINE_SEPARATOR, "\n") + require(lineSeparator.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") } private[libsvm] object LibSVMOptions { @@ -48,4 +51,5 @@ private[libsvm] object LibSVMOptions { val VECTOR_TYPE = "vectorType" val DENSE_VECTOR_TYPE = "dense" val SPARSE_VECTOR_TYPE = "sparse" + val LINE_SEPARATOR = "lineSep" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4e84ff044f55e..4c928ce4a753e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -41,6 +41,7 @@ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( path: String, dataSchema: StructType, + lineSeparator: String, context: TaskAttemptContext) extends OutputWriter { @@ -57,7 +58,7 @@ private[libsvm] class LibSVMOutputWriter( writer.write(s" ${i + 1}:$v") } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { @@ -100,7 +101,7 @@ private[libsvm] class LibSVMFileFormat "'numFeatures' option to avoid the extra scan.") val paths = files.map(_.getPath.toUri.toString) - val parsed = MLUtils.parseLibSVMFile(sparkSession, paths) + val parsed = MLUtils.parseLibSVMFile(sparkSession, paths, libSVMOptions.lineSeparator) MLUtils.computeNumFeatures(parsed) } @@ -120,12 +121,13 @@ private[libsvm] class LibSVMFileFormat options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { verifySchema(dataSchema, true) + val libSVMOptions = new LibSVMOptions(options) new OutputWriterFactory { override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new LibSVMOutputWriter(path, dataSchema, context) + new LibSVMOutputWriter(path, dataSchema, libSVMOptions.lineSeparator, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -153,7 +155,8 @@ private[libsvm] class LibSVMFileFormat sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val linesReader = new HadoopFileLinesReader( + file, libSVMOptions.lineSeparator, broadcastedHadoopConf.value.value) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val points = linesReader diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 14af8b5c73870..187bd4555de16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -30,7 +30,7 @@ import 
org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -105,12 +105,17 @@ object MLUtils extends Logging { } private[spark] def parseLibSVMFile( - sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = { + sparkSession: SparkSession, + paths: Seq[String], + lineSeparator: String): RDD[(Double, Array[Int], Array[Double])] = { + val textOptions = Map(TextOptions.LINE_SEPARATOR -> lineSeparator) + val lines = sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value") diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 3eabff434e8de..3c053cac1f181 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -184,4 +184,62 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { spark.sql("DROP TABLE IF EXISTS libsvmTable") } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + val data = Seq( + "1.0 1:1.0 3:2.0 5:3.0", "0.0", "0.0", "0.0 2:4.0 4:5.0 6:6.0").mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + val path0 = new File(tempDir.getCanonicalPath, "write0") + val path1 = new File(tempDir.getCanonicalPath, "write1") + try { + // Read + java.nio.file.Files.write(path0.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read + .option("lineSep", lineSep) + .format("libsvm") + .load(path0.getAbsolutePath) + + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + + val results = df.collect() + + assert(results.map(_.getDouble(0)).toSet == Seq(1.0, 0.0, 0.0, 0.0).toSet) + + val actual = results.map(_.getAs[SparseVector](1)) + val expected = Seq( + Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))), + Vectors.sparse(6, Nil), + Vectors.sparse(6, Nil), + Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) + assert(actual.toSet == expected.toSet) + + // Write + df.coalesce(1) + .write.option("lineSep", lineSep).format("libsvm").save(path1.getAbsolutePath) + val partFile = Utils.recursiveList(path1).filter(f => f.getName.startsWith("part-")).head + val readBack = new String( + java.nio.file.Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack == dataWithTrailingLineSep) + + // Roundtrip + val readBackDF = spark.read + .option("lineSep", lineSep) + .format("libsvm") + .load(path1.getAbsolutePath) + assert(df.collect().toSet == readBackDF.collect().toSet) + } finally { + Utils.deleteRecursively(path0) + Utils.deleteRecursively(path1) + } + } + } + } + + Seq("123!!@", "^", "&@").foreach { lineSep => + 
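// (Editor's sketch) The user-facing shape of what this suite verifies, assuming the
// `lineSep` option name introduced by this patch; separators are literal strings,
// not regexes, since they are passed as bytes to Hadoop's LineRecordReader:
//   val df = spark.read.option("lineSep", "&@").format("libsvm").load(inputPath)
//   df.write.option("lineSep", "&@").format("libsvm").save(outputPath)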
testLineSeparator(lineSep) + } } diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f3092918abb54..6173e9e45ffff 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -304,7 +306,7 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths): + def text(self, paths, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there @@ -313,11 +315,14 @@ def text(self, paths): Each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ + self._set_opts(lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -342,7 +347,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, - it uses the default value, ``UTF-8``. + it uses the default value, ``UTF-8``. Note that, currently this option + does not support non-ascii compatible encodings. :param quote: sets the single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ``"``. 
If you would like to turn off quotations, you need to set an @@ -730,7 +736,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -753,12 +760,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep) self._jwrite.json(path) @since(1.4) @@ -788,18 +798,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 0cf702143c773..adce8b7c434a2 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -407,7 +407,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -470,6 +470,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``. 
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -484,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -514,7 +516,7 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there @@ -525,6 +527,8 @@ def text(self, path): .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -532,6 +536,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: @@ -558,7 +563,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, - it uses the default value, ``UTF-8``. + it uses the default value, ``UTF-8``. Note that, currently this option + does not support non-ascii compatible encodings. :param quote: sets the single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default value, ``"``. 
If you would like to turn off quotations, you need to set an diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98afae662b42d..8fa0f0c2360c7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -508,12 +508,51 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..e36d5ada9398f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -85,6 +85,9 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val lineSeparator: String = parameters.getOrElse("lineSep", "\n") + require(lineSeparator.nonEmpty, "'lineSep' cannot be an empty string.") + /** Sets config options on a Jackson [[JsonFactory]]. 
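 * (Editor's note) `lineSeparator` above is consumed by both sides of the JSON source:
 * the line-based reader and [[JacksonGenerator]]'s `writeLineEnding()`. A user-facing
 * sketch, assuming the `lineSep` option name wired up in this patch (path hypothetical):
 * {{{
 *   spark.read.option("lineSep", "!!").json(path)   // parse !!-delimited records
 *   df.write.option("lineSep", "!!").json(path)     // emit !! after each record
 * }}}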
*/ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0a..ca7097f2496ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSep = options.lineSeparator + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,5 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = gen.writeRaw(lineSep) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 17966eecfc051..ae87701970750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -370,6 +370,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
<li>`multiLine` (default `false`): parse one record, which may span multiple lines,
 * per file</li>
+ * <li>`lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.</li>
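 * (Editor's sketch; option name as added above, path hypothetical:)
 * {{{
 *   spark.read.option("lineSep", "||").json("/data/in.json")
 * }}}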
 * </ul>
 *
 * @since 2.0.0
@@ -515,7 +517,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 * <li>`sep` (default `,`): sets the single character as a separator for each
 * field and value.</li>
 * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
- * type.</li>
+ * type. Note that, currently this option does not support non-ascii compatible encodings.</li>
 * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
 * the separator can be part of the value. If you would like to turn off quotations, you need to
 * set not `null` but an empty string. This behaviour is different from
@@ -655,6 +657,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 * spark.read().text("/path/to/spark/README.md")
 * }}}
 *
+ * You can set the following text-specific options to deal with text files:
+ * <ul>
+ * <li>`lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.</li>
+ * </ul>
+ *
 * @param paths input paths
 * @since 1.6.0
 */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 8d95b24c00619..66bd357cf2a6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -503,6 +503,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
 * indicates a timestamp format. Custom date formats follow the formats at
 * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+ * <li>`lineSep` (default is `\n`): defines the line separator that should
+ * be used for writing.</li>
 * </ul>
 *
 * @since 1.4.0
@@ -572,6 +574,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
 * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
 * `snappy` and `deflate`).</li>
+ * <li>`lineSep` (default is `\n`): defines the line separator that should
+ * be used for writing.</li>
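 * (Editor's sketch) For instance, writing pipe-delimited text output with the option
 * added above; `df` and the path are hypothetical:
 * {{{
 *   df.write.option("lineSep", "|").text("/tmp/pipe-delimited")
 * }}}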
 * </ul>
 *
 * @since 1.6.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
index 83cf26c63a175..18d261b0f5bab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
@@ -30,9 +30,16 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 /**
  * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines
  * in that file.
+ *
+ * @param file A part (i.e. "block") of a single file that should be read line by line.
+ * @param lineSeparator A line separator that should be used for each line. If the value is `\n`,
+ *                      it covers `\r`, `\r\n` and `\n`.
+ * @param conf Hadoop configuration
  */
 class HadoopFileLinesReader(
-    file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable {
+    file: PartitionedFile,
+    lineSeparator: String,
+    conf: Configuration) extends Iterator[Text] with Closeable {
   private val iterator = {
     val fileSplit = new FileSplit(
       new Path(new URI(file.filePath)),
@@ -42,7 +49,12 @@ class HadoopFileLinesReader(
       Array.empty)
     val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
     val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
-    val reader = new LineRecordReader()
+    val reader = if (lineSeparator != "\n") {
+      new LineRecordReader(lineSeparator.getBytes("UTF-8"))
+    } else {
+      // This behavior follows Hive. `\n` covers `\r`, `\r\n` and `\n`.
+      new LineRecordReader()
+    }
     reader.initialize(fileSplit, hadoopAttemptContext)
     new RecordReaderIterator(reader)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 2031381dd2e10..84435feda0a8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -32,9 +32,8 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
 import org.apache.spark.sql.types.StructType

 /**
@@ -129,7 +128,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       parser: UnivocityParser,
       schema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, conf)
+      val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparator, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -178,13 +177,16 @@ object TextInputCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       options: CSVOptions): Dataset[String] = {
+    val textOptions = Map(TextOptions.LINE_SEPARATOR -> options.lineSeparator)
+
    val paths =
inputPaths.map(_.getPath.toString) if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index a13a5a34b4a84..f4bcaa37a2d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -138,6 +138,9 @@ class CSVOptions( val quoteAll = getBool("quoteAll", false) + val lineSeparator: String = parameters.getOrElse("lineSep", "\n") + require(lineSeparator.nonEmpty, "'lineSep' cannot be an empty string.") + val inputBufferSize = 128 val isCommentSet = this.comment != '\u0000' @@ -149,6 +152,9 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) + if (lineSeparator != "\n") { + format.setLineSeparator(lineSeparator) + } writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) @@ -166,6 +172,9 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) + if (lineSeparator != "\n") { + format.setLineSeparator(lineSeparator) + } settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 8b7c2709afde1..90c910f1e438f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -91,7 +91,8 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset( + sparkSession, inputPaths, parsedOptions.lineSeparator) inferFromDataset(json, parsedOptions) } @@ -103,13 +104,17 @@ object TextInputJsonDataSource extends JsonDataSource { private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + lineSeparator: String): 
Dataset[String] = { + val textOptions = Map(TextOptions.LINE_SEPARATOR -> lineSeparator) + val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -119,7 +124,7 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparator, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val safeParser = new FailureSafeParser[Text]( input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index d0690445d7672..3ae7850380f04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.StandardCharsets + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -77,7 +79,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparator, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -98,11 +100,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { requiredSchema.length <= 1, "Text data source only produces a single data column named \"value\".") + val textOptions = new TextOptions(options) + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val reader = new HadoopFileLinesReader( + file, textOptions.lineSeparator, broadcastedHadoopConf.value.value) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) if (requiredSchema.isEmpty) { @@ -128,9 +133,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: String, context: TaskAttemptContext) extends OutputWriter { + private val lineSep = lineSeparator.getBytes(StandardCharsets.UTF_8) + private val writer = CodecStreams.createOutputStream(context, new Path(path)) override def write(row: InternalRow): Unit = { @@ -138,7 +146,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSep) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 
49bd7382f9cf3..afa82e8ee6696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -33,8 +33,12 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti * Compression codec to use. */ val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) + + val lineSeparator: String = parameters.getOrElse(LINE_SEPARATOR, "\n") + require(lineSeparator.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") } -private[text] object TextOptions { +private[spark] object TextOptions { val COMPRESSION = "compression" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index a42e28053a96a..9e6efc2c51589 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -222,6 +222,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
<li>`multiLine` (default `false`): parse one record, which may span multiple lines,
 * per file</li>
+ * <li>`lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.</li>
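 * (Editor's sketch) e.g., reading a `||`-delimited JSON stream, assuming a pre-defined
 * `schema` for the source and a hypothetical path:
 * {{{
 *   spark.readStream.schema(schema).option("lineSep", "||").json("/data/in")
 * }}}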
 * </ul>
 *
 * @since 2.0.0
@@ -242,7 +244,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
 * <li>`sep` (default `,`): sets the single character as a separator for each
 * field and value.</li>
 * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
- * type.</li>
+ * type. Note that, currently this option does not support non-ascii compatible encodings.</li>
 * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
 * the separator can be part of the value. If you would like to turn off quotations, you need to
 * set not `null` but an empty string. This behaviour is different from
@@ -334,6 +336,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
+ * <li>`lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.</li>
 * </ul>
 *
 * @since 2.0.0
@@ -360,6 +364,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
+ * <li>`lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.</li>
 * </ul>
 *
 * @param path input path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index e439699605abb..f51161b83e701 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.csv

 import java.io.File
-import java.nio.charset.UnsupportedCharsetException
+import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
+import java.nio.file.Files
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 import java.util.Locale
@@ -30,10 +31,10 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.functions.{col, regexp_replace}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils

 class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   import testImplicits._
@@ -1245,4 +1246,55 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil
     )
   }
+
+  def testLineSeparator(lineSep: String, multiLine: Boolean): Unit = {
+    test(s"SPARK-21289: Support line separator - lineSep: '$lineSep' and multiLine: $multiLine") {
+      // Read
+      val data = Seq("a,b", "1,\"a\nd\"", "1,f").mkString(lineSep)
+      val dataWithTrailingLineSep = s"$data$lineSep"
+
+      Seq(data, dataWithTrailingLineSep).foreach { lines =>
+        withTempPath { path =>
+          Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8))
+          val df = spark.read
+            .option("multiLine", multiLine)
+            .option("lineSep", lineSep)
+            .option("header", true)
+            .option("inferSchema", true)
+            .csv(path.getAbsolutePath)
+
+          val expectedSchema =
+            StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
+          checkAnswer(df, Seq((1, "a\nd"), (1, "f")).toDF())
+          assert(df.schema === expectedSchema)
+        }
+      }
+
+      // Write
+      withTempPath { path =>
+        Seq("a", "b", "c").toDF().coalesce(1)
+          .write.option("lineSep", lineSep).csv(path.getAbsolutePath)
+        val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+        val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+        assert(readBack === s"a${lineSep}b${lineSep}c$lineSep")
+      }
+
+      // Roundtrip
+      withTempPath { path =>
+        val df = Seq("a", "b", "c").toDF()
+        df.write.option("lineSep", lineSep).csv(path.getAbsolutePath)
+        val readBack = spark.read
+          .option("multiLine", multiLine)
+          .option("lineSep", lineSep)
+          .csv(path.getAbsolutePath)
+        checkAnswer(df, readBack)
+      }
+    }
+  }
+
+  Seq("|", "^", "::", "\r\n").foreach { lineSep
=> + Seq(true, false).foreach { multiLine => + testLineSeparator(lineSep, multiLine) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..d23fcc4d54f1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -2063,4 +2064,51 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3").foreach { lineSep => + testLineSeparator(lineSep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index cb7393cdd2b9d..7cf7a16b5d884 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec @@ -172,6 +174,43 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, 
lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3").foreach { lineSep => + testLineSeparator(lineSep) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 60a4638f610b3..965f0e14684a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -95,7 +95,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { val projection = new InterpretedProjection(outputAttributes, inputAttributes) val unsafeRowIterator = - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line => + new HadoopFileLinesReader(file, "\n", broadcastedHadoopConf.value.value).map { line => val record = line.toString new GenericInternalRow(record.split(",", -1).zip(fieldTypes).map { case (v, dataType) =>
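
(Editor's note) Taken together, the patch threads one `lineSep` option from the public
readers and writers down to `HadoopFileLinesReader` on the read path (as literal UTF-8
bytes handed to Hadoop's `LineRecordReader`) and to each `OutputWriter` on the write
path. A minimal end-to-end sketch of the resulting behavior, assuming an active
`spark` session with `import spark.implicits._` (paths hypothetical):

    // Write three rows with a custom separator; the file on disk is "a||b||c||".
    Seq("a", "b", "c").toDF("value").coalesce(1)
      .write.option("lineSep", "||").text("/tmp/lines")
    // Read them back; the same literal separator must be supplied, since only the
    // default `\n` also covers `\r` and `\r\n` on read.
    val back = spark.read.option("lineSep", "||").text("/tmp/lines")  // three rows
    // An empty separator is rejected eagerly by the require(lineSeparator.nonEmpty, ...)
    // checks added to TextOptions, JSONOptions, CSVOptions and LibSVMOptions.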