@@ -41,11 +41,15 @@ private[libsvm] class LibSVMOptions(@transient private val parameters: CaseInsen
case o => throw new IllegalArgumentException(s"Invalid value `$o` for parameter " +
s"`$VECTOR_TYPE`. Expected types are `sparse` and `dense`.")
}

val lineSeparator: String = parameters.getOrElse(LINE_SEPARATOR, "\n")
require(lineSeparator.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
}

private[libsvm] object LibSVMOptions {
val NUM_FEATURES = "numFeatures"
val VECTOR_TYPE = "vectorType"
val DENSE_VECTOR_TYPE = "dense"
val SPARSE_VECTOR_TYPE = "sparse"
val LINE_SEPARATOR = "lineSep"
Member: Please use the full name.

Member Author: This name came after sep in CSV, which resembled R. Do you prefer separator and lineSeparator?

Member: Either lineDelimiter or lineSeparator looks fine.

In the future, we could also support field delimiters.

Member Author: I actually meant sep here:

* <li>`sep` (default `,`): sets the single character as a separator for each
* field and value.</li>

and was thinking of matching the name.

Member Author: Just for history, it was delimiter but renamed to sep.

}
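
For context, here is a minimal usage sketch of the option this file adds, assuming the name stays `lineSep` as in this patch; the paths and the "!" separator are placeholders, and this mirrors the round trip exercised in LibSVMRelationSuite further down:

// Read a libsvm file whose records are separated by "!" instead of "\n".
val df = spark.read
  .option("lineSep", "!")
  .format("libsvm")
  .load("/tmp/sample_libsvm_data.txt")   // placeholder path

// Write it back with the same separator; each record is terminated by "!".
df.coalesce(1)
  .write
  .option("lineSep", "!")
  .format("libsvm")
  .save("/tmp/sample_libsvm_out")        // placeholder path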
@@ -41,6 +41,7 @@ import org.apache.spark.util.SerializableConfiguration
private[libsvm] class LibSVMOutputWriter(
path: String,
dataSchema: StructType,
lineSeparator: String,
context: TaskAttemptContext)
extends OutputWriter {

@@ -57,7 +58,7 @@ private[libsvm] class LibSVMOutputWriter(
writer.write(s" ${i + 1}:$v")
}

writer.write('\n')
writer.write(lineSeparator)
}

override def close(): Unit = {
@@ -100,7 +101,7 @@ private[libsvm] class LibSVMFileFormat
"'numFeatures' option to avoid the extra scan.")

val paths = files.map(_.getPath.toUri.toString)
val parsed = MLUtils.parseLibSVMFile(sparkSession, paths)
val parsed = MLUtils.parseLibSVMFile(sparkSession, paths, libSVMOptions.lineSeparator)
MLUtils.computeNumFeatures(parsed)
}

@@ -120,12 +121,13 @@ private[libsvm] class LibSVMFileFormat
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema, true)
val libSVMOptions = new LibSVMOptions(options)
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new LibSVMOutputWriter(path, dataSchema, context)
new LibSVMOutputWriter(path, dataSchema, libSVMOptions.lineSeparator, context)
}

override def getFileExtension(context: TaskAttemptContext): String = {
@@ -153,7 +155,8 @@ private[libsvm] class LibSVMFileFormat
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

(file: PartitionedFile) => {
val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
val linesReader = new HadoopFileLinesReader(
file, libSVMOptions.lineSeparator, broadcastedHadoopConf.value.value)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))

val points = linesReader
11 changes: 8 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.BernoulliCellSampler
@@ -105,12 +105,17 @@ object MLUtils extends Logging {
}

private[spark] def parseLibSVMFile(
sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = {
sparkSession: SparkSession,
paths: Seq[String],
lineSeparator: String): RDD[(Double, Array[Int], Array[Double])] = {
val textOptions = Map(TextOptions.LINE_SEPARATOR -> lineSeparator)

val lines = sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
className = classOf[TextFileFormat].getName
className = classOf[TextFileFormat].getName,
options = textOptions
).resolveRelation(checkFilesExist = false))
.select("value")

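As a sketch only, the change above is roughly equivalent at the public API level to reading the raw records with a caller-supplied separator via the text source (MLUtils itself goes through DataSource directly; the path below is a placeholder):

// Roughly what parseLibSVMFile now does internally: read raw records with a
// custom separator, then parse each record as a labeled point.
val lineSeparator = "!"                    // caller-supplied, defaults to "\n"
val lines = spark.read
  .option("lineSep", lineSeparator)        // presumably the same key as TextOptions.LINE_SEPARATOR
  .text("/tmp/sample_libsvm_data.txt")     // placeholder path
  .select("value")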
@@ -184,4 +184,62 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
spark.sql("DROP TABLE IF EXISTS libsvmTable")
}
}

def testLineSeparator(lineSep: String): Unit = {
test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") {
val data = Seq(
"1.0 1:1.0 3:2.0 5:3.0", "0.0", "0.0", "0.0 2:4.0 4:5.0 6:6.0").mkString(lineSep)
val dataWithTrailingLineSep = s"$data$lineSep"

Seq(data, dataWithTrailingLineSep).foreach { lines =>
val path0 = new File(tempDir.getCanonicalPath, "write0")
val path1 = new File(tempDir.getCanonicalPath, "write1")
try {
// Read
java.nio.file.Files.write(path0.toPath, lines.getBytes(StandardCharsets.UTF_8))
Contributor: Why not import this? java.nio.file.Files

Member Author (@HyukjinKwon, Dec 4, 2017): To differentiate it from Google's Files explicitly above. Not a big deal.

val df = spark.read
.option("lineSep", lineSep)
.format("libsvm")
.load(path0.getAbsolutePath)

assert(df.columns(0) == "label")
assert(df.columns(1) == "features")

val results = df.collect()

assert(results.map(_.getDouble(0)).toSet == Seq(1.0, 0.0, 0.0, 0.0).toSet)

val actual = results.map(_.getAs[SparseVector](1))
val expected = Seq(
Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))),
Vectors.sparse(6, Nil),
Vectors.sparse(6, Nil),
Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
assert(actual.toSet == expected.toSet)

// Write
df.coalesce(1)
.write.option("lineSep", lineSep).format("libsvm").save(path1.getAbsolutePath)
val partFile = Utils.recursiveList(path1).filter(f => f.getName.startsWith("part-")).head
val readBack = new String(
java.nio.file.Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
assert(readBack == dataWithTrailingLineSep)

// Roundtrip
val readBackDF = spark.read
.option("lineSep", lineSep)
.format("libsvm")
.load(path1.getAbsolutePath)
assert(df.collect().toSet == readBackDF.collect().toSet)
} finally {
Utils.deleteRecursively(path0)
Utils.deleteRecursively(path1)
}
}
}
}

Seq("123!!@", "^", "&@").foreach { lineSep =>
testLineSeparator(lineSep)
}
}
28 changes: 20 additions & 8 deletions python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None):
multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.

@@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.

>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
allowUnquotedControlChars=allowUnquotedControlChars)
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -304,7 +306,7 @@ def parquet(self, *paths):

@ignore_unicode_prefix
@since(1.6)
def text(self, paths):
def text(self, paths, lineSep=None):
"""
Loads text files and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -313,11 +315,14 @@ def text(self, paths):
Each line in the text file is a new row in the resulting DataFrame.

:param paths: string, or list of strings, for input path(s).
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.

>>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
self._set_opts(lineSep=lineSep)
if isinstance(paths, basestring):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
@@ -342,7 +347,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
it uses the default value, ``UTF-8``.
it uses the default value, ``UTF-8``. Note that, currently this option
does not support non-ascii compatible encodings.
:param quote: sets the single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
@@ -730,7 +736,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
self._jwrite.saveAsTable(name)

@since(1.4)
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None):
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
lineSep=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
specified path.
@@ -753,12 +760,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
:param lineSep: defines the line separator that should be used for writing. If None is
set, it uses the default value, ``\\n``.

>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat)
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
lineSep=lineSep)
self._jwrite.json(path)

@since(1.4)
@@ -788,18 +798,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
self._jwrite.parquet(path)

@since(1.6)
def text(self, path, compression=None):
def text(self, path, compression=None, lineSep=None):
"""Saves the content of the DataFrame in a text file at the specified path.

:param path: the path in any Hadoop supported file system
:param compression: compression codec to use when saving to file. This can be one of the
known case-insensitive shorten names (none, bzip2, gzip, lz4,
snappy and deflate).
:param lineSep: defines the line separator that should be used for writing. If None is
set, it uses the default value, ``\\n``.

The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
"""
self._set_opts(compression=compression)
self._set_opts(compression=compression, lineSep=lineSep)
self._jwrite.text(path)

@since(2.0)
14 changes: 10 additions & 4 deletions python/pyspark/sql/streaming.py
@@ -407,7 +407,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None):
multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.

@@ -470,6 +470,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.

>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -484,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
allowUnquotedControlChars=allowUnquotedControlChars)
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
@@ -514,7 +516,7 @@ def parquet(self, path):

@ignore_unicode_prefix
@since(2.0)
def text(self, path):
def text(self, path, lineSep=None):
"""
Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -525,13 +527,16 @@ def text(self, path):
.. note:: Evolving.

:param paths: string, or list of strings, for input path(s).
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it uses ``\\n`` by default, covering ``\\r``, ``\\r\\n`` and ``\\n``.

>>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
>>> text_sdf.isStreaming
True
>>> "value" in str(text_sdf.schema)
True
"""
self._set_opts(lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.text(path))
else:
@@ -558,7 +563,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param sep: sets the single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
it uses the default value, ``UTF-8``.
it uses the default value, ``UTF-8``. Note that, currently this option
does not support non-ascii compatible encodings.
:param quote: sets the single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
41 changes: 40 additions & 1 deletion python/pyspark/sql/tests.py
@@ -508,12 +508,51 @@ def test_non_existed_udaf(self):
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_multiLine_json(self):
def test_linesep_text(self):
df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
self.assertEqual(df.collect(), expected)

tpath = tempfile.mkdtemp()
shutil.rmtree(tpath)
try:
df.write.text(tpath, lineSep="!")
expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
Row(value=u'Tom!30!"My name is Tom"'),
Row(value=u'Hyukjin!25!"I am Hyukjin'),
Row(value=u''), Row(value=u'I love Spark!"'),
Row(value=u'!')]
readback = self.spark.read.text(tpath)
self.assertEqual(readback.collect(), expected)
finally:
shutil.rmtree(tpath)

def test_multiline_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())

def test_linesep_json(self):
df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
expected = [Row(_corrupt_record=None, name=u'Michael'),
Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
Row(_corrupt_record=u' "age":19}\n', name=None)]
self.assertEqual(df.collect(), expected)

tpath = tempfile.mkdtemp()
shutil.rmtree(tpath)
try:
df = self.spark.read.json("python/test_support/sql/people.json")
df.write.json(tpath, lineSep="!!")
readback = self.spark.read.json(tpath, lineSep="!!")
self.assertEqual(readback.collect(), df.collect())
finally:
shutil.rmtree(tpath)

def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True)
@@ -85,6 +85,9 @@ private[sql] class JSONOptions(

val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)

val lineSeparator: String = parameters.getOrElse("lineSep", "\n")
require(lineSeparator.nonEmpty, "'lineSep' cannot be an empty string.")

/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
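
A minimal round-trip sketch for the JSON side, mirroring test_linesep_json above; the paths are placeholders and the option name follows this patch:

// Read a JSON file whose records are separated by "," rather than newlines.
val people = spark.read
  .option("lineSep", ",")
  .json("/tmp/people.json")                // placeholder path

// Write one record per "!!"-separated line, then read it back with the same option.
people.write
  .option("lineSep", "!!")
  .json("/tmp/people_out")                 // placeholder path

val readBack = spark.read.option("lineSep", "!!").json("/tmp/people_out")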