diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 726de4a965418..1d2dd4d808930 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
+            samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
         r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -453,6 +453,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
                        it uses the default value, ``en-US``. For instance, ``locale`` is used while
                        parsing dates and timestamps.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+                        set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+                        Maximum length is 1 character.
 
         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
@@ -472,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
-            enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
+            enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -868,7 +871,7 @@ def text(self, path, compression=None, lineSep=None):
     def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
             header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
             timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
-            charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None):
+            charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None):
         r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -922,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
                          the default UTF-8 charset will be used.
         :param emptyValue: sets the string representation of an empty value. If None is set, it uses
                            the default value, ``""``.
+        :param lineSep: defines the line separator that should be used for writing. If None is
+                        set, it uses the default value, ``\\n``. Maximum length is 1 character.
 
         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
@@ -932,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
                        ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
                        ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
                        charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
-                       encoding=encoding, emptyValue=emptyValue)
+                       encoding=encoding, emptyValue=emptyValue, lineSep=lineSep)
         self._jwrite.csv(path)
 
     @since(1.5)
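The readwriter.py changes above only thread the new ``lineSep`` keyword through to the underlying options map. A minimal sketch of how the option might be exercised once this patch is applied; the path and sample data are hypothetical:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a", 1), ("b", 2)], ["k", "v"])

    # Write records separated by '\r' instead of the default '\n'.
    df.write.csv("/tmp/csv_cr", lineSep="\r")

    # Read them back with the same separator; leaving lineSep unset would
    # instead accept any of '\r', '\r\n' and '\n'.
    spark.read.csv("/tmp/csv_cr", lineSep="\r").show()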
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 58ca7b83e5b2b..d92b0d5677e25 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -576,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            enforceSchema=None, emptyValue=None, locale=None):
+            enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
         r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -675,6 +675,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
                        it uses the default value, ``en-US``. For instance, ``locale`` is used while
                        parsing dates and timestamps.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+                        set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+                        Maximum length is 1 character.
 
         >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
         >>> csv_sdf.isStreaming
@@ -692,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
-            emptyValue=emptyValue, locale=locale)
+            emptyValue=emptyValue, locale=locale, lineSep=lineSep)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:
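The streaming reader mirrors the batch API. A sketch under the same assumptions as above (``readStream.csv`` requires an explicit schema, and the directory is hypothetical):

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    schema = StructType([StructField("k", StringType()),
                         StructField("v", IntegerType())])

    # Stream CSV files whose records are separated by '\r'.
    csv_sdf = spark.readStream.csv("/tmp/csv_cr", schema=schema, lineSep="\r")
    assert csv_sdf.isStreaming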
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 6bb50b42a369c..94bdb72d675d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -192,6 +192,20 @@ class CSVOptions(
    */
   val emptyValueInWrite = emptyValue.getOrElse("\"\"")
 
+  /**
+   * A string between two consecutive CSV records.
+   */
+  val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
+    require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+    require(sep.length == 1, "'lineSep' can contain only 1 character.")
+    sep
+  }
+
+  val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+    lineSep.getBytes(charset)
+  }
+  val lineSeparatorInWrite: Option[String] = lineSeparator
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat
@@ -200,6 +214,8 @@ class CSVOptions(
     format.setQuoteEscape(escape)
     charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
     format.setComment(comment)
+    lineSeparatorInWrite.foreach(format.setLineSeparator)
+
     writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
     writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
     writerSettings.setNullValue(nullValue)
@@ -216,8 +232,10 @@ class CSVOptions(
     format.setDelimiter(delimiter)
     format.setQuote(quote)
     format.setQuoteEscape(escape)
+    lineSeparator.foreach(format.setLineSeparator)
     charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
     format.setComment(comment)
+
     settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead)
     settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead)
     settings.setReadInputOnSeparateThread(false)
@@ -227,7 +245,10 @@ class CSVOptions(
     settings.setEmptyValue(emptyValueInRead)
     settings.setMaxCharsPerColumn(maxCharsPerColumn)
     settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
-    settings.setLineSeparatorDetectionEnabled(multiLine == true)
+    settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine)
+    lineSeparatorInRead.foreach { _ =>
+      settings.setNormalizeLineEndingsWithinQuotes(!multiLine)
+    }
 
     settings
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index f08fd64acd9a1..da88598eed061 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -609,6 +609,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   *
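CSVOptions is where the semantics are actually pinned down: ``lineSep`` must be exactly one character, univocity's automatic separator detection stays enabled only when no explicit separator is given and ``multiLine`` is set, and with an explicit separator the line endings inside quoted values are normalized only in non-multiLine mode. A hedged sketch of the user-visible consequence of the ``require`` checks (path hypothetical; the exact point at which the error surfaces may vary):

    # A multi-character separator should be rejected once the options are
    # materialized into CSVOptions.
    try:
        spark.read.csv("/tmp/csv_cr", lineSep="||").collect()
    except Exception as e:
        print(e)  # expect: 'lineSep' can contain only 1 character.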