40 changes: 35 additions & 5 deletions python/pyspark/sql/readwriter.py
@@ -121,7 +121,12 @@ def option(self, key, value):
in the JSON/CSV datasources or partition values.
If it isn't set, it uses the default value, session local timezone.
"""
self._jreader = self._jreader.option(key, to_str(value))
if isinstance(value, (list, tuple)):
gateway = self._spark._sc._gateway
jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
self._jreader = self._jreader.option(key, jvalues)
else:
self._jreader = self._jreader.option(key, to_str(value))
return self

@since(1.4)
@@ -134,7 +139,14 @@ def options(self, **options):
If it isn't set, it uses the default value, session local timezone.
"""
for k in options:
self._jreader = self._jreader.option(k, to_str(options[k]))
self.option(k, options[k])
return self

@since(2.3)
def unsetOption(self, key):
"""Un-sets the option given to the key for the underlying data source.
"""
self._jreader = self._jreader.unsetOption(key)
return self
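
A minimal usage sketch of the list-valued option and the new unsetOption on the batch reader (the CSV path reuses the test file from the doctests below; everything else is standard API):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # A list value is converted to a java.lang.String[] via Py4J and handed
    # to the JVM-side option(key, Array[String]) overload added by this patch.
    reader = spark.read.option("nullValue", ["N/A", "null"])

    # unsetOption removes a previously set option before the load happens.
    df = reader.unsetOption("nullValue").csv("python/test_support/sql/ages.csv")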

@since(1.4)
@@ -322,6 +334,7 @@ def text(self, paths):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))

@ignore_unicode_prefix
@since(2.0)
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
@@ -362,7 +375,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
uses the default value, ``false``.
:param nullValue: sets the string representation of a null value. If None is set, it uses
the default value, empty string. Since 2.0.1, this ``nullValue`` param
applies to all supported types including the string type.
applies to all supported types including the string type. This option also
accepts a list or tuple of strings, each of which is treated as a null value.
:param nanValue: sets the string representation of a non-number value. If None is set, it
uses the default value, ``NaN``.
:param positiveInf: sets the string representation of a positive infinity value. If None
@@ -408,6 +422,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]

>>> df = spark.read.csv('python/test_support/sql/ages.csv', nullValue=['Tom', 'Joe'])
>>> df.collect()
[Row(_c0=None, _c1=u'20'), Row(_c0=None, _c1=u'30'), Row(_c0=u'Hyukjin', _c1=u'25')]
"""
self._set_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
@@ -544,7 +562,12 @@ def option(self, key, value):
timestamps in the JSON/CSV datasources or partition values.
If it isn't set, it uses the default value, session local timezone.
"""
self._jwrite = self._jwrite.option(key, to_str(value))
if isinstance(value, (list, tuple)):
gateway = self._spark._sc._gateway
jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
self._jwrite = self._jwrite.option(key, jvalues)
else:
self._jwrite = self._jwrite.option(key, to_str(value))
return self

@since(1.4)
@@ -557,7 +580,14 @@ def options(self, **options):
If it isn't set, it uses the default value, session local timezone.
"""
for k in options:
self._jwrite = self._jwrite.option(k, to_str(options[k]))
self.option(k, options[k])
return self

@since(2.3)
def unsetOption(self, key):
"""Un-sets the option given to the key for the underlying data source.
"""
self._jwrite = self._jwrite.unsetOption(key)
return self
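
Continuing the reader sketch above on the writer side, mirroring the updated test_save_and_load at the bottom of this diff (the scratch output path is hypothetical):

    import os
    import tempfile

    out = os.path.join(tempfile.mkdtemp(), "data")

    # Setting then unsetting 'quote' restores the data source default;
    # the updated test below relies on exactly this behavior.
    df.write.option("quote", '"').unsetOption("quote").format("csv").save(out)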

@since(1.4)
39 changes: 34 additions & 5 deletions python/pyspark/sql/streaming.py
@@ -349,7 +349,12 @@ def option(self, key, value):

>>> s = spark.readStream.option("x", 1)
"""
self._jreader = self._jreader.option(key, to_str(value))
if isinstance(value, (list, tuple)):
gateway = self._spark._sc._gateway
jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
self._jreader = self._jreader.option(key, jvalues)
else:
self._jreader = self._jreader.option(key, to_str(value))
return self

@since(2.0)
@@ -366,7 +371,16 @@ def options(self, **options):
>>> s = spark.readStream.options(x="1", y=2)
"""
for k in options:
self._jreader = self._jreader.option(k, to_str(options[k]))
self.option(k, options[k])
return self

@since(2.3)
def unsetOption(self, key):
"""Un-sets the option given to the key for the underlying data source.

.. note:: Evolving.
"""
self._jreader = self._jreader.unsetOption(key)
return self
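
The streaming reader mirrors the batch behavior; a brief sketch with a hypothetical input directory (streaming file sources require an explicit schema):

    # List values and unsetOption work the same way on DataStreamReader.
    stream_df = (spark.readStream
        .schema("name STRING, age STRING")
        .option("nullValue", ["", "N/A"])
        .csv("/tmp/stream-input"))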

@since(2.0)
@@ -579,7 +593,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
uses the default value, ``false``.
:param nullValue: sets the string representation of a null value. If None is set, it uses
the default value, empty string. Since 2.0.1, this ``nullValue`` param
applies to all supported types including the string type.
applies to all supported types including the string type. This option also
accepts a list or tuple of strings, each of which is treated as a null value.
:param nanValue: sets the string representation of a non-number value. If None is set, it
uses the default value, ``NaN``.
:param positiveInf: sets the string representation of a positive infinity value. If None
@@ -710,7 +725,12 @@ def option(self, key, value):

.. note:: Evolving.
"""
self._jwrite = self._jwrite.option(key, to_str(value))
if isinstance(value, (list, tuple)):
gateway = self._spark._sc._gateway
jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
self._jwrite = self._jwrite.option(key, jvalues)
else:
self._jwrite = self._jwrite.option(key, to_str(value))
return self

@since(2.0)
@@ -725,7 +745,16 @@ def options(self, **options):
.. note:: Evolving.
"""
for k in options:
self._jwrite = self._jwrite.option(k, to_str(options[k]))
self.option(k, options[k])
return self

@since(2.3)
def unsetOption(self, key):
"""Un-sets the option given to the key for the underlying data source.

.. note:: Evolving.
"""
self._jwrite = self._jwrite.unsetOption(key)
return self

@since(2.0)
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
@@ -1369,7 +1369,7 @@ def test_save_and_load(self):
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

csvpath = os.path.join(tempfile.mkdtemp(), 'data')
df.write.option('quote', None).format('csv').save(csvpath)
df.write.unsetOption('quote').format('csv').save(csvpath)

shutil.rmtree(tmpPath)

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -297,12 +297,21 @@ tablePropertyKey
;

tablePropertyValue
: propertyValue
| propertyArrayValue
;

propertyValue
: INTEGER_VALUE
| DECIMAL_VALUE
| booleanValue
| STRING
;

propertyArrayValue
: '[' propertyValue (',' propertyValue)* ']'
;
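
Assuming the OPTIONS clause of CREATE TABLE routes its values through tablePropertyValue, the new array rule admits statements like this sketch (table name and path are hypothetical):

    spark.sql("""
        CREATE TABLE people (name STRING, age STRING)
        USING csv
        OPTIONS (path 'python/test_support/sql/ages.csv', nullValue ['', 'N/A'])
    """)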

constantList
: '(' constant (',' constant)* ')'
;
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, Properties}

import scala.collection.JavaConverters._

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.api.java.JavaRDD
@@ -116,6 +119,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
def option(key: String, value: Double): DataFrameReader = option(key, value.toString)

/**
* (Scala-specific) Adds an input option for the underlying data source.
*
* @since 2.3.0
*/
def option(key: String, value: Seq[String]): DataFrameReader = {
option(key, compact(render(value)))
}

/**
* Adds an input option for the underlying data source.
*
* @since 2.3.0
*/
def option(key: String, value: Array[String]): DataFrameReader = option(key, value.toSeq)
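
The Seq overload round-trips through JSON rather than Py4J arrays; a small sketch of the expected wire format, assuming json4s' compact/render emits standard minified JSON:

    import json

    # compact(render(Seq("", "null"))) on the Scala side should produce the
    # same string as a minified json.dumps: '["","null"]'
    encoded = json.dumps(["", "null"], separators=(",", ":"))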

/**
* (Scala-specific) Adds input options for the underlying data source.
*
@@ -148,6 +167,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
this
}

/**
* Unsets an input option previously set for the underlying data source.
*
* @since 2.3.0
*/
def unsetOption(key: String): DataFrameReader = {
this.extraOptions.remove(key)
this
}

/**
* Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
* key-value stores).
@@ -502,7 +531,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
* whitespaces from values being read should be skipped.</li>
* <li>`nullValue` (default empty string): sets the string representation of a null value. Since
* 2.0.1, this applies to all supported types including the string type.</li>
* 2.0.1, this applies to all supported types including the string type.
* A `Seq` or `Array` of strings can also be set, each element of which is treated
* as a null value. For example,
* `spark.read.format("csv").option("nullValue", Seq("", "null"))`.</li>
* <li>`nanValue` (default `NaN`): sets the string representation of a non-number value.</li>
* <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.</li>
29 changes: 29 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, Properties}

import scala.collection.JavaConverters._

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
@@ -125,6 +128,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*/
def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)

/**
* (Scala-specific) Adds an output option for the underlying data source.
*
* @since 2.3.0
*/
def option(key: String, value: Seq[String]): DataFrameWriter[T] = {
option(key, compact(render(value)))
}

/**
* Adds an output option for the underlying data source.
*
* @since 2.3.0
*/
def option(key: String, value: Array[String]): DataFrameWriter[T] = option(key, value.toSeq)

/**
* (Scala-specific) Adds output options for the underlying data source.
*
@@ -157,6 +176,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
this
}

/**
* Unsets an output option previously set for the underlying data source.
*
* @since 2.3.0
*/
def unsetOption(key: String): DataFrameWriter[T] = {
this.extraOptions.remove(key)
this
}

/**
* Partitions the output by the given columns on the file system. If specified, the output is
* laid out on the file system similar to Hive's partitioning scheme. As an example, when we
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.TerminalNode
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
@@ -587,6 +589,21 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* the property value based on whether it's a string, integer, boolean, or decimal literal.
*/
override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
if (value == null) {
null
} else if (value.propertyValue != null) {
visitPropertyValue(value.propertyValue)
} else if (value.propertyArrayValue != null) {
val values = value.propertyArrayValue.propertyValue.asScala.map { v =>
visitPropertyValue(v)
}
compact(render(values))
} else {
value.getText
}
}

override def visitPropertyValue(value: PropertyValueContext): String = {
if (value == null) {
null
} else if (value.STRING != null) {
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -79,7 +79,7 @@ private[csv] object CSVInferSchema {
* point checking if it is an Int, as the final type must be Double or higher.
*/
def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
if (field == null || field.isEmpty || field == options.nullValue) {
if (field == null || field.isEmpty || options.nullValue.contains(field)) {
typeSoFar
} else {
typeSoFar match {
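The inference change above swaps an equality test against a single nullValue for membership in the parsed array; a short Python sketch of the new predicate:

    # Before: field == null_value; after: field in null_values.
    def is_null_field(field, null_values):
        return field is None or field == "" or field in null_values

    is_null_field("N/A", ["", "N/A"])  # True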
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.StandardCharsets
import java.util.{Locale, TimeZone}

import scala.util.Try

import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling}
import org.apache.commons.lang3.time.FastDateFormat
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
@@ -104,7 +108,17 @@ class CSVOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)

val nullValue = parameters.getOrElse("nullValue", "")
val nullValue: Array[String] = parameters.get("nullValue").map { str =>
Try {
implicit val formats = DefaultFormats
val arr = parse(str).extract[Array[String]]
// If the input is a plain string rather than a JSON array, extraction can
// yield an empty array. Throw in that case so that the Try falls back to
// using the raw string as-is, for backwards compatibility.
require(arr.nonEmpty)
arr
}.getOrElse(Array(str))
}.getOrElse(Array(""))
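
A Python sketch of the parse-or-fall-back behavior above, with the standard json module standing in for json4s:

    import json

    def parse_null_value(raw):
        # Accept a non-empty JSON array of strings; on any failure, fall
        # back to treating the raw string as the single null token, which
        # preserves the pre-2.3 behavior of a plain-string nullValue.
        try:
            arr = json.loads(raw)
            if isinstance(arr, list) and arr and all(isinstance(x, str) for x in arr):
                return arr
            raise ValueError(raw)
        except ValueError:
            return [raw]

    parse_null_value('["", "null"]')  # ['', 'null']
    parse_null_value('N/A')           # ['N/A']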

val nanValue = parameters.getOrElse("nanValue", "NaN")

@@ -151,8 +165,8 @@ class CSVOptions(
format.setComment(comment)
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
writerSettings.setEmptyValue(nullValue)
writerSettings.setNullValue(nullValue(0))
writerSettings.setEmptyValue(nullValue(0))
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -171,7 +185,7 @@ class CSVOptions(
settings.setReadInputOnSeparateThread(false)
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
settings.setNullValue(nullValue)
settings.setNullValue(nullValue(0))
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings