diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
index 6900b4153a7eb..8be23894a732f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
@@ -41,6 +41,9 @@ private[libsvm] class LibSVMOptions(@transient private val parameters: CaseInsen
case o => throw new IllegalArgumentException(s"Invalid value `$o` for parameter " +
s"`$VECTOR_TYPE`. Expected types are `sparse` and `dense`.")
}
+
+ val lineSeparator: String = parameters.getOrElse(LINE_SEPARATOR, "\n")
+ require(lineSeparator.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
}
private[libsvm] object LibSVMOptions {
@@ -48,4 +51,5 @@ private[libsvm] object LibSVMOptions {
val VECTOR_TYPE = "vectorType"
val DENSE_VECTOR_TYPE = "dense"
val SPARSE_VECTOR_TYPE = "sparse"
+ val LINE_SEPARATOR = "lineSep"
}
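
For context, a minimal usage sketch of the new LibSVM option. It assumes an interactive `spark` session (as in spark-shell); the paths and the "|" separator are placeholders, not part of this patch.

// Read a LibSVM file whose records are delimited by "|" rather than "\n".
val df = spark.read
  .format("libsvm")
  .option("lineSep", "|")
  .load("/tmp/sample_libsvm_input")

// Write it back out with the same custom separator.
df.write
  .format("libsvm")
  .option("lineSep", "|")
  .save("/tmp/sample_libsvm_output")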
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 4e84ff044f55e..4c928ce4a753e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -41,6 +41,7 @@ import org.apache.spark.util.SerializableConfiguration
private[libsvm] class LibSVMOutputWriter(
path: String,
dataSchema: StructType,
+ lineSeparator: String,
context: TaskAttemptContext)
extends OutputWriter {
@@ -57,7 +58,7 @@ private[libsvm] class LibSVMOutputWriter(
writer.write(s" ${i + 1}:$v")
}
- writer.write('\n')
+ writer.write(lineSeparator)
}
override def close(): Unit = {
@@ -100,7 +101,7 @@ private[libsvm] class LibSVMFileFormat
"'numFeatures' option to avoid the extra scan.")
val paths = files.map(_.getPath.toUri.toString)
- val parsed = MLUtils.parseLibSVMFile(sparkSession, paths)
+ val parsed = MLUtils.parseLibSVMFile(sparkSession, paths, libSVMOptions.lineSeparator)
MLUtils.computeNumFeatures(parsed)
}
@@ -120,12 +121,13 @@ private[libsvm] class LibSVMFileFormat
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema, true)
+ val libSVMOptions = new LibSVMOptions(options)
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new LibSVMOutputWriter(path, dataSchema, context)
+ new LibSVMOutputWriter(path, dataSchema, libSVMOptions.lineSeparator, context)
}
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -153,7 +155,8 @@ private[libsvm] class LibSVMFileFormat
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
(file: PartitionedFile) => {
- val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+ val linesReader = new HadoopFileLinesReader(
+ file, libSVMOptions.lineSeparator, broadcastedHadoopConf.value.value)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
val points = linesReader
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 14af8b5c73870..187bd4555de16 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.BernoulliCellSampler
@@ -105,12 +105,17 @@ object MLUtils extends Logging {
}
private[spark] def parseLibSVMFile(
- sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = {
+ sparkSession: SparkSession,
+ paths: Seq[String],
+ lineSeparator: String): RDD[(Double, Array[Int], Array[Double])] = {
+ val textOptions = Map(TextOptions.LINE_SEPARATOR -> lineSeparator)
+
val lines = sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
- className = classOf[TextFileFormat].getName
+ className = classOf[TextFileFormat].getName,
+ options = textOptions
).resolveRelation(checkFilesExist = false))
.select("value")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 3eabff434e8de..3c053cac1f181 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -184,4 +184,62 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
spark.sql("DROP TABLE IF EXISTS libsvmTable")
}
}
+
+ def testLineSeparator(lineSep: String): Unit = {
+ test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") {
+ val data = Seq(
+ "1.0 1:1.0 3:2.0 5:3.0", "0.0", "0.0", "0.0 2:4.0 4:5.0 6:6.0").mkString(lineSep)
+ val dataWithTrailingLineSep = s"$data$lineSep"
+
+ Seq(data, dataWithTrailingLineSep).foreach { lines =>
+ val path0 = new File(tempDir.getCanonicalPath, "write0")
+ val path1 = new File(tempDir.getCanonicalPath, "write1")
+ try {
+ // Read
+ java.nio.file.Files.write(path0.toPath, lines.getBytes(StandardCharsets.UTF_8))
+ val df = spark.read
+ .option("lineSep", lineSep)
+ .format("libsvm")
+ .load(path0.getAbsolutePath)
+
+ assert(df.columns(0) == "label")
+ assert(df.columns(1) == "features")
+
+ val results = df.collect()
+
+ assert(results.map(_.getDouble(0)).toSet == Seq(1.0, 0.0, 0.0, 0.0).toSet)
+
+ val actual = results.map(_.getAs[SparseVector](1))
+ val expected = Seq(
+ Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))),
+ Vectors.sparse(6, Nil),
+ Vectors.sparse(6, Nil),
+ Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+ assert(actual.toSet == expected.toSet)
+
+ // Write
+ df.coalesce(1)
+ .write.option("lineSep", lineSep).format("libsvm").save(path1.getAbsolutePath)
+ val partFile = Utils.recursiveList(path1).filter(f => f.getName.startsWith("part-")).head
+ val readBack = new String(
+ java.nio.file.Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+ assert(readBack == dataWithTrailingLineSep)
+
+ // Roundtrip
+ val readBackDF = spark.read
+ .option("lineSep", lineSep)
+ .format("libsvm")
+ .load(path1.getAbsolutePath)
+ assert(df.collect().toSet == readBackDF.collect().toSet)
+ } finally {
+ Utils.deleteRecursively(path0)
+ Utils.deleteRecursively(path1)
+ }
+ }
+ }
+ }
+
+ Seq("123!!@", "^", "&@").foreach { lineSep =>
+ testLineSeparator(lineSep)
+ }
}
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f3092918abb54..6173e9e45ffff 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -304,7 +306,7 @@ def parquet(self, *paths):
@ignore_unicode_prefix
@since(1.6)
- def text(self, paths):
+ def text(self, paths, lineSep=None):
"""
Loads text files and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -313,11 +315,14 @@ def text(self, paths):
Each line in the text file is a new row in the resulting DataFrame.
:param paths: string, or list of strings, for input path(s).
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.
>>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
+ self._set_opts(lineSep=lineSep)
if isinstance(paths, basestring):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
@@ -342,7 +347,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
- it uses the default value, ``UTF-8``.
+ it uses the default value, ``UTF-8``. Note that currently this option
+ does not support non-ASCII compatible encodings.
:param quote: sets the single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
@@ -730,7 +736,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
self._jwrite.saveAsTable(name)
@since(1.4)
- def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None):
+ def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
+ lineSep=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
specified path.
@@ -753,12 +760,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
+ :param lineSep: defines the line separator that should be used for writing. If None is
+ set, it uses the default value, ``\\n``.
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
- compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat)
+ compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
+ lineSep=lineSep)
self._jwrite.json(path)
@since(1.4)
@@ -788,18 +798,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
self._jwrite.parquet(path)
@since(1.6)
- def text(self, path, compression=None):
+ def text(self, path, compression=None, lineSep=None):
"""Saves the content of the DataFrame in a text file at the specified path.
:param path: the path in any Hadoop supported file system
:param compression: compression codec to use when saving to file. This can be one of the
known case-insensitive shorten names (none, bzip2, gzip, lz4,
snappy and deflate).
+ :param lineSep: defines the line separator that should be used for writing. If None is
+ set, it uses the default value, ``\\n``.
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
"""
- self._set_opts(compression=compression)
+ self._set_opts(compression=compression, lineSep=lineSep)
self._jwrite.text(path)
@since(2.0)
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 0cf702143c773..adce8b7c434a2 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -407,7 +407,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
@@ -470,6 +470,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -484,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
@@ -514,7 +516,7 @@ def parquet(self, path):
@ignore_unicode_prefix
@since(2.0)
- def text(self, path):
+ def text(self, path, lineSep=None):
"""
Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -525,6 +527,8 @@ def text(self, path):
.. note:: Evolving.
:param path: string, or list of strings, for input path(s).
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.
>>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
>>> text_sdf.isStreaming
@@ -532,6 +536,7 @@ def text(self, path):
>>> "value" in str(text_sdf.schema)
True
"""
+ self._set_opts(lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.text(path))
else:
@@ -558,7 +563,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
- it uses the default value, ``UTF-8``.
+ it uses the default value, ``UTF-8``. Note that currently this option
+ does not support non-ASCII compatible encodings.
:param quote: sets the single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 98afae662b42d..8fa0f0c2360c7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -508,12 +508,51 @@ def test_non_existed_udaf(self):
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
- def test_multiLine_json(self):
+ def test_linesep_text(self):
+ df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
+ expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
+ Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
+ Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
+ Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df.write.text(tpath, lineSep="!")
+ expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
+ Row(value=u'Tom!30!"My name is Tom"'),
+ Row(value=u'Hyukjin!25!"I am Hyukjin'),
+ Row(value=u''), Row(value=u'I love Spark!"'),
+ Row(value=u'!')]
+ readback = self.spark.read.text(tpath)
+ self.assertEqual(readback.collect(), expected)
+ finally:
+ shutil.rmtree(tpath)
+
+ def test_multiline_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())
+ def test_linesep_json(self):
+ df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
+ expected = [Row(_corrupt_record=None, name=u'Michael'),
+ Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
+ Row(_corrupt_record=u' "age":19}\n', name=None)]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df = self.spark.read.json("python/test_support/sql/people.json")
+ df.write.json(tpath, lineSep="!!")
+ readback = self.spark.read.json(tpath, lineSep="!!")
+ self.assertEqual(readback.collect(), df.collect())
+ finally:
+ shutil.rmtree(tpath)
+
def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 652412b34478a..e36d5ada9398f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -85,6 +85,9 @@ private[sql] class JSONOptions(
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
+ val lineSeparator: String = parameters.getOrElse("lineSep", "\n")
+ require(lineSeparator.nonEmpty, "'lineSep' cannot be an empty string.")
+
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index eb06e4f304f0a..ca7097f2496ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.json
import java.io.Writer
+import java.nio.charset.StandardCharsets
import com.fasterxml.jackson.core._
@@ -74,6 +75,8 @@ private[sql] class JacksonGenerator(
private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ private val lineSep = options.lineSeparator
+
private def makeWriter(dataType: DataType): ValueWriter = dataType match {
case NullType =>
(row: SpecializedGetters, ordinal: Int) =>
@@ -251,5 +254,5 @@ private[sql] class JacksonGenerator(
mapType = dataType.asInstanceOf[MapType]))
}
- def writeLineEnding(): Unit = gen.writeRaw('\n')
+ def writeLineEnding(): Unit = gen.writeRaw(lineSep)
}
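
The generator change above only affects the record terminator on the write path; on the read path the separator is picked up through JSONOptions.lineSeparator. A rough end-to-end sketch, assuming a spark-shell session and placeholder paths:

import spark.implicits._

// Each JSON record is terminated by ";" instead of "\n".
Seq(("a", 1), ("b", 2)).toDF("k", "v")
  .coalesce(1)
  .write.option("lineSep", ";").json("/tmp/json_linesep_out")

// Reading back needs the same separator, since the reader splits on "\n" by default.
spark.read.option("lineSep", ";").json("/tmp/json_linesep_out").show()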
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 17966eecfc051..ae87701970750 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -370,6 +370,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+ * `lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @since 2.0.0
@@ -515,7 +517,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `sep` (default `,`): sets the single character as a separator for each
* field and value.
* `encoding` (default `UTF-8`): decodes the CSV files by the given encoding
- * type.
+ * type. Note that currently this option does not support non-ASCII compatible encodings.
* `quote` (default `"`): sets the single character used for escaping quoted values where
* the separator can be part of the value. If you would like to turn off quotations, you need to
* set not `null` but an empty string. This behaviour is different from
@@ -655,6 +657,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().text("/path/to/spark/README.md")
* }}}
*
+ * You can set the following text-specific options to deal with text files:
+ *
+ * - `lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
+ *
+ *
* @param paths input paths
* @since 1.6.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 8d95b24c00619..66bd357cf2a6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -503,6 +503,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
+ * `lineSep` (default is `\n`): defines the line separator that should
+ * be used for writing.
*
*
* @since 1.4.0
@@ -572,6 +574,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).
+ * `lineSep` (default is `\n`): defines the line separator that should
+ * be used for writing.
*
*
* @since 1.6.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
index 83cf26c63a175..18d261b0f5bab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
@@ -30,9 +30,16 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
/**
* An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines
* in that file.
+ *
+ * @param file A part (i.e. "block") of a single file that should be read line by line.
+ * @param lineSeparator The separator used to split the file into lines. If the value is `\n`,
+ * it also covers `\r` and `\r\n`.
+ * @param conf Hadoop configuration
*/
class HadoopFileLinesReader(
- file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable {
+ file: PartitionedFile,
+ lineSeparator: String,
+ conf: Configuration) extends Iterator[Text] with Closeable {
private val iterator = {
val fileSplit = new FileSplit(
new Path(new URI(file.filePath)),
@@ -42,7 +49,12 @@ class HadoopFileLinesReader(
Array.empty)
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
- val reader = new LineRecordReader()
+ val reader = if (lineSeparator != "\n") {
+ new LineRecordReader(lineSeparator.getBytes("UTF-8"))
+ } else {
+ // This behavior follows Hive. `\n` covers `\r`, `\r\n` and `\n`.
+ new LineRecordReader()
+ }
reader.initialize(fileSplit, hadoopAttemptContext)
new RecordReaderIterator(reader)
}
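
The `lineSeparator != "\n"` branch preserves the existing Hive-compatible behaviour: with the default separator, Hadoop's LineRecordReader accepts `\r`, `\r\n` and `\n` as record boundaries, while an explicit separator is matched byte-for-byte. A user-level illustration, assuming a spark-shell session with made-up file contents and paths:

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

// No lineSep set: CRLF-terminated input still yields three rows.
Files.write(Paths.get("/tmp/crlf.txt"), "a\r\nb\r\nc".getBytes(StandardCharsets.UTF_8))
spark.read.text("/tmp/crlf.txt").count()                          // 3

// Explicit lineSep: only the exact byte sequence "|" delimits records.
Files.write(Paths.get("/tmp/pipe.txt"), "a|b|c".getBytes(StandardCharsets.UTF_8))
spark.read.option("lineSep", "|").text("/tmp/pipe.txt").count()   // 3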
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 2031381dd2e10..84435feda0a8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -32,9 +32,8 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
import org.apache.spark.sql.types.StructType
/**
@@ -129,7 +128,7 @@ object TextInputCSVDataSource extends CSVDataSource {
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow] = {
val lines = {
- val linesReader = new HadoopFileLinesReader(file, conf)
+ val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparator, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -178,13 +177,16 @@ object TextInputCSVDataSource extends CSVDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
options: CSVOptions): Dataset[String] = {
+ val textOptions = Map(TextOptions.LINE_SEPARATOR -> options.lineSeparator)
+
val paths = inputPaths.map(_.getPath.toString)
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
- className = classOf[TextFileFormat].getName
+ className = classOf[TextFileFormat].getName,
+ options = textOptions
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index a13a5a34b4a84..f4bcaa37a2d57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -138,6 +138,9 @@ class CSVOptions(
val quoteAll = getBool("quoteAll", false)
+ val lineSeparator: String = parameters.getOrElse("lineSep", "\n")
+ require(lineSeparator.nonEmpty, "'lineSep' cannot be an empty string.")
+
val inputBufferSize = 128
val isCommentSet = this.comment != '\u0000'
@@ -149,6 +152,9 @@ class CSVOptions(
format.setQuote(quote)
format.setQuoteEscape(escape)
format.setComment(comment)
+ if (lineSeparator != "\n") {
+ format.setLineSeparator(lineSeparator)
+ }
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
@@ -166,6 +172,9 @@ class CSVOptions(
format.setQuote(quote)
format.setQuoteEscape(escape)
format.setComment(comment)
+ if (lineSeparator != "\n") {
+ format.setLineSeparator(lineSeparator)
+ }
settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead)
settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead)
settings.setReadInputOnSeparateThread(false)
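
For CSV the option is forwarded to univocity's writer and parser settings only when it differs from `\n`, so the default behaviour is unchanged. A short usage sketch, assuming a spark-shell session and a placeholder path:

import spark.implicits._

// Write rows delimited by "^" instead of newlines, then read them back.
Seq("a", "b", "c").toDF("value")
  .coalesce(1)
  .write.option("lineSep", "^").csv("/tmp/csv_linesep_out")

spark.read.option("lineSep", "^").csv("/tmp/csv_linesep_out").show()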
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 8b7c2709afde1..90c910f1e438f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -91,7 +91,8 @@ object TextInputJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
- val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+ val json: Dataset[String] = createBaseDataset(
+ sparkSession, inputPaths, parsedOptions.lineSeparator)
inferFromDataset(json, parsedOptions)
}
@@ -103,13 +104,17 @@ object TextInputJsonDataSource extends JsonDataSource {
private def createBaseDataset(
sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): Dataset[String] = {
+ inputPaths: Seq[FileStatus],
+ lineSeparator: String): Dataset[String] = {
+ val textOptions = Map(TextOptions.LINE_SEPARATOR -> lineSeparator)
+
val paths = inputPaths.map(_.getPath.toString)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
- className = classOf[TextFileFormat].getName
+ className = classOf[TextFileFormat].getName,
+ options = textOptions
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}
@@ -119,7 +124,7 @@ object TextInputJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
- val linesReader = new HadoopFileLinesReader(file, conf)
+ val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparator, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
val safeParser = new FailureSafeParser[Text](
input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index d0690445d7672..3ae7850380f04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.text
+import java.nio.charset.StandardCharsets
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -77,7 +79,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new TextOutputWriter(path, dataSchema, context)
+ new TextOutputWriter(path, dataSchema, textOptions.lineSeparator, context)
}
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -98,11 +100,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
requiredSchema.length <= 1,
"Text data source only produces a single data column named \"value\".")
+ val textOptions = new TextOptions(options)
+
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
(file: PartitionedFile) => {
- val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+ val reader = new HadoopFileLinesReader(
+ file, textOptions.lineSeparator, broadcastedHadoopConf.value.value)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close()))
if (requiredSchema.isEmpty) {
@@ -128,9 +133,12 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
class TextOutputWriter(
path: String,
dataSchema: StructType,
+ lineSeparator: String,
context: TaskAttemptContext)
extends OutputWriter {
+ private val lineSep = lineSeparator.getBytes(StandardCharsets.UTF_8)
+
private val writer = CodecStreams.createOutputStream(context, new Path(path))
override def write(row: InternalRow): Unit = {
@@ -138,7 +146,7 @@ class TextOutputWriter(
val utf8string = row.getUTF8String(0)
utf8string.writeTo(writer)
}
- writer.write('\n')
+ writer.write(lineSep)
}
override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
index 49bd7382f9cf3..afa82e8ee6696 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -33,8 +33,12 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti
* Compression codec to use.
*/
val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
+
+ val lineSeparator: String = parameters.getOrElse(LINE_SEPARATOR, "\n")
+ require(lineSeparator.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
}
-private[text] object TextOptions {
+private[spark] object TextOptions {
val COMPRESSION = "compression"
+ val LINE_SEPARATOR = "lineSep"
}
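
As with the other sources, an empty separator is rejected eagerly by the `require` above. A sketch of the guard from the user's side, assuming a spark-shell session and a placeholder path; the exact point at which the exception surfaces depends on when the options are parsed:

import spark.implicits._

// Expected to fail with java.lang.IllegalArgumentException:
// requirement failed: 'lineSep' cannot be an empty string.
Seq("a", "b").toDF("value")
  .write.option("lineSep", "").text("/tmp/never_written")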
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index a42e28053a96a..9e6efc2c51589 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -222,6 +222,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+ * `lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @since 2.0.0
@@ -242,7 +244,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `sep` (default `,`): sets the single character as a separator for each
* field and value.
* `encoding` (default `UTF-8`): decodes the CSV files by the given encoding
- * type.
+ * type. Note that currently this option does not support non-ASCII compatible encodings.
* `quote` (default `"`): sets the single character used for escaping quoted values where
* the separator can be part of the value. If you would like to turn off quotations, you need to
* set not `null` but an empty string. This behaviour is different from
@@ -334,6 +336,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*
* - `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
* considered in every trigger.
+ * - `lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @since 2.0.0
@@ -360,6 +364,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*
* - `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
* considered in every trigger.
+ * - `lineSep` (default is `\n`, covering `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @param path input path
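
The streaming reader exposes the same option, so a feed that uses a fixed delimiter can be consumed without pre-splitting it. A sketch, assuming a spark-shell session and placeholder paths:

// Tail a directory of "|"-delimited text files and echo the rows to the console.
val lines = spark.readStream
  .option("lineSep", "|")
  .text("/tmp/streaming_in")

val query = lines.writeStream
  .format("console")
  .start()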
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index e439699605abb..f51161b83e701 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution.datasources.csv
import java.io.File
-import java.nio.charset.UnsupportedCharsetException
+import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
+import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale
@@ -30,10 +31,10 @@ import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.functions.{col, regexp_replace}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
import testImplicits._
@@ -1245,4 +1246,55 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil
)
}
+
+ def testLineSeparator(lineSep: String, multiLine: Boolean): Unit = {
+ test(s"SPARK-21289: Support line separator - lineSep: '$lineSep' and multiLine: $multiLine") {
+ // Read
+ val data = Seq("a,b", "1,\"a\nd\"", "1,f").mkString(lineSep)
+ val dataWithTrailingLineSep = s"$data$lineSep"
+
+ Seq(data, dataWithTrailingLineSep).foreach { lines =>
+ withTempPath { path =>
+ Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8))
+ val df = spark.read
+ .option("multiLine", multiLine)
+ .option("lineSep", lineSep)
+ .option("header", true)
+ .option("inferSchema", true)
+ .csv(path.getAbsolutePath)
+
+ val expectedSchema =
+ StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
+ checkAnswer(df, Seq((1, "a\nd"), (1, "f")).toDF())
+ assert(df.schema === expectedSchema)
+ }
+ }
+
+ // Write
+ withTempPath { path =>
+ Seq("a", "b", "c").toDF().coalesce(1)
+ .write.option("lineSep", lineSep).csv(path.getAbsolutePath)
+ val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+ val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+ assert(readBack === s"a${lineSep}b${lineSep}c$lineSep")
+ }
+
+ // Roundtrip
+ withTempPath { path =>
+ val df = Seq("a", "b", "c").toDF()
+ df.write.option("lineSep", lineSep).csv(path.getAbsolutePath)
+ val readBack = spark.read
+ .option("multiLine", multiLine)
+ .option("lineSep", lineSep)
+ .csv(path.getAbsolutePath)
+ checkAnswer(df, readBack)
+ }
+ }
+ }
+
+ Seq("|", "^", "::", "\r\n").foreach { lineSep =>
+ Seq(true, false).foreach { multiLine =>
+ testLineSeparator(lineSep, multiLine)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 8c8d41ebf115a..d23fcc4d54f1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json
import java.io.{File, StringWriter}
import java.nio.charset.StandardCharsets
+import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.util.Locale
@@ -2063,4 +2064,51 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
}
+
+ def testLineSeparator(lineSep: String): Unit = {
+ test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") {
+ // Read
+ val data =
+ s"""
+ | {"f":
+ |"a", "f0": 1}$lineSep{"f":
+ |
+ |"c", "f0": 2}$lineSep{"f": "d", "f0": 3}
+ """.stripMargin
+ val dataWithTrailingLineSep = s"$data$lineSep"
+
+ Seq(data, dataWithTrailingLineSep).foreach { lines =>
+ withTempPath { path =>
+ Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8))
+ val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath)
+ val expectedSchema =
+ StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil)
+ checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF())
+ assert(df.schema === expectedSchema)
+ }
+ }
+
+ // Write
+ withTempPath { path =>
+ Seq("a", "b", "c").toDF("value").coalesce(1)
+ .write.option("lineSep", lineSep).json(path.getAbsolutePath)
+ val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+ val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+ assert(
+ readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""")
+ }
+
+ // Roundtrip
+ withTempPath { path =>
+ val df = Seq("a", "b", "c").toDF()
+ df.write.option("lineSep", lineSep).json(path.getAbsolutePath)
+ val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath)
+ checkAnswer(df, readBack)
+ }
+ }
+ }
+
+ Seq("|", "^", "::", "!!!@3").foreach { lineSep =>
+ testLineSeparator(lineSep)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index cb7393cdd2b9d..7cf7a16b5d884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.execution.datasources.text
import java.io.File
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec
@@ -172,6 +174,43 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
}
+ def testLineSeparator(lineSep: String): Unit = {
+ test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") {
+ // Read
+ val values = Seq("a", "b", "\nc")
+ val data = values.mkString(lineSep)
+ val dataWithTrailingLineSep = s"$data$lineSep"
+ Seq(data, dataWithTrailingLineSep).foreach { lines =>
+ withTempPath { path =>
+ Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8))
+ val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath)
+ checkAnswer(df, Seq("a", "b", "\nc").toDF())
+ }
+ }
+
+ // Write
+ withTempPath { path =>
+ values.toDF().coalesce(1)
+ .write.option("lineSep", lineSep).text(path.getAbsolutePath)
+ val partFile = Utils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+ val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+ assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}")
+ }
+
+ // Roundtrip
+ withTempPath { path =>
+ val df = values.toDF()
+ df.write.option("lineSep", lineSep).text(path.getAbsolutePath)
+ val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath)
+ checkAnswer(df, readBack)
+ }
+ }
+ }
+
+ Seq("|", "^", "::", "!!!@3").foreach { lineSep =>
+ testLineSeparator(lineSep)
+ }
+
private def testFile: String = {
Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 60a4638f610b3..965f0e14684a6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -95,7 +95,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
val projection = new InterpretedProjection(outputAttributes, inputAttributes)
val unsafeRowIterator =
- new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line =>
+ new HadoopFileLinesReader(file, "\n", broadcastedHadoopConf.value.value).map { line =>
val record = line.toString
new GenericInternalRow(record.split(",", -1).zip(fieldTypes).map {
case (v, dataType) =>