diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 726de4a965418..1d2dd4d808930 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -453,6 +453,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -472,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -868,7 +871,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None): r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -922,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the default UTF-8 charset will be used. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, ``""``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. Maximum length is 1 character. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -932,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, - encoding=encoding, emptyValue=emptyValue) + encoding=encoding, emptyValue=emptyValue, lineSep=lineSep) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58ca7b83e5b2b..d92b0d5677e25 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -576,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None, locale=None): + enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -675,6 +675,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -692,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue, locale=locale) + emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 6bb50b42a369c..94bdb72d675d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -192,6 +192,20 @@ class CSVOptions( */ val emptyValueInWrite = emptyValue.getOrElse("\"\"") + /** + * A string between two consecutive JSON records. + */ + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + require(sep.length == 1, "'lineSep' can contain only 1 character.") + sep + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(charset) + } + val lineSeparatorInWrite: Option[String] = lineSeparator + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -200,6 +214,8 @@ class CSVOptions( format.setQuoteEscape(escape) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + lineSeparatorInWrite.foreach(format.setLineSeparator) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) @@ -216,8 +232,10 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + lineSeparator.foreach(format.setLineSeparator) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) @@ -227,7 +245,10 @@ class CSVOptions( settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - settings.setLineSeparatorDetectionEnabled(multiLine == true) + settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine) + lineSeparatorInRead.foreach { _ => + settings.setNormalizeLineEndingsWithinQuotes(!multiLine) + } settings } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f08fd64acd9a1..da88598eed061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -609,6 +609,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • *
  • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. * For instance, this is used while parsing dates and timestamps.
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 29d479f542115..5a807d3d4b93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * whitespaces from values being written should be skipped. *
  • `ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not * trailing whitespaces from values being written should be skipped.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing. + * Maximum length is 1 character.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 554baaf1a9b3b..b35b8851918b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource { headerChecker: CSVHeaderChecker, requiredSchema: StructType): Iterator[InternalRow] = { val lines = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index e4250145a1ae2..c8e3e1c191044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -377,6 +377,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • *
  • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. * For instance, this is used while parsing dates and timestamps.
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e29cd2aa7c4e6..c275d63d32cc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -33,7 +33,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.log4j.{AppenderSkeleton, LogManager} import org.apache.log4j.spi.LoggingEvent -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -1880,4 +1880,110 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("""Support line separator - default value \r, \r\n and \n""") { + val data = "\"a\",1\r\"c\",2\r\n\"d\",3\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("inferSchema", true).csv(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("_c0", StringType) :: StructField("_c1", IntegerType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + def testLineSeparator(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"Support line separator in ${encoding} #${id}") { + // Read + val data = + s""""a",1$lineSep + |c,2$lineSep" + |d",3""".stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(encoding)) + val schema = StructType(StructField("_c0", StringType) + :: StructField("_c1", LongType) :: Nil) + + val expected = Seq(("a", 1), ("\nc", 2), ("\nd", 3)) + .toDF("_c0", "_c1") + Seq(false, true).foreach { multiLine => + val reader = spark + .read + .option("lineSep", lineSep) + .option("multiLine", multiLine) + .option("encoding", encoding) + val df = if (inferSchema) { + reader.option("inferSchema", true).csv(path.getAbsolutePath) + } else { + reader.schema(schema).csv(path.getAbsolutePath) + } + checkAnswer(df, expected) + } + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), encoding) + assert( + readBack === s"a${lineSep}b${lineSep}c${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val readBack = spark + .read + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, ":", "ISO-8859-1", true), + (3, "!", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "у", "CP1251", true), + (8, "\r", "UTF-16LE", true), + (9, "\u000d", "UTF-32BE", false), + (10, "=", "US-ASCII", false), + (11, "$", "utf-32le", true) + ).foreach { case (testNum, sep, encoding, inferSchema) => + testLineSeparator(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("lineSep restrictions") { + val errMsg1 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg1.contains("'lineSep' cannot be an empty string")) + + val errMsg2 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "123").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg2.contains("'lineSep' can contain only 1 character")) + } }