diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index cb847a0420311..887930e9a44c4 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -121,7 +121,12 @@ def option(self, key, value):
in the JSON/CSV datasources or partition values.
If it isn't set, it uses the default value, session local timezone.
"""
- self._jreader = self._jreader.option(key, to_str(value))
+ if isinstance(value, (list, tuple)):
+ gateway = self._spark._sc._gateway
+ jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
+ self._jreader = self._jreader.option(key, jvalues)
+ else:
+ self._jreader = self._jreader.option(key, to_str(value))
return self
@since(1.4)
@@ -134,7 +139,14 @@ def options(self, **options):
If it isn't set, it uses the default value, session local timezone.
"""
for k in options:
- self._jreader = self._jreader.option(k, to_str(options[k]))
+ self.option(k, options[k])
+ return self
+
+ @since(2.3)
+ def unsetOption(self, key):
+ """Un-sets the option given to the key for the underlying data source.
+ """
+ self._jreader = self._jreader.unsetOption(key)
return self
@since(1.4)
@@ -322,6 +334,7 @@ def text(self, paths):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
+ @ignore_unicode_prefix
@since(2.0)
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
@@ -362,7 +375,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
uses the default value, ``false``.
:param nullValue: sets the string representation of a null value. If None is set, it uses
the default value, empty string. Since 2.0.1, this ``nullValue`` param
- applies to all supported types including the string type.
+ applies to all supported types including the string type. A list or tuple
+ of strings can also be set for this option to represent multiple null values.
:param nanValue: sets the string representation of a non-number value. If None is set, it
uses the default value, ``NaN``.
:param positiveInf: sets the string representation of a positive infinity value. If None
@@ -408,6 +422,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
+
+ >>> df = spark.read.csv('python/test_support/sql/ages.csv', nullValue=['Tom', 'Joe'])
+ >>> df.collect()
+ [Row(_c0=None, _c1=u'20'), Row(_c0=None, _c1=u'30'), Row(_c0=u'Hyukjin', _c1=u'25')]
"""
self._set_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
@@ -544,7 +562,12 @@ def option(self, key, value):
timestamps in the JSON/CSV datasources or partition values.
If it isn't set, it uses the default value, session local timezone.
"""
- self._jwrite = self._jwrite.option(key, to_str(value))
+ if isinstance(value, (list, tuple)):
+ gateway = self._spark._sc._gateway
+ jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
+ self._jwrite = self._jwrite.option(key, jvalues)
+ else:
+ self._jwrite = self._jwrite.option(key, to_str(value))
return self
@since(1.4)
@@ -557,7 +580,14 @@ def options(self, **options):
If it isn't set, it uses the default value, session local timezone.
"""
for k in options:
- self._jwrite = self._jwrite.option(k, to_str(options[k]))
+ self.option(k, options[k])
+ return self
+
+ @since(2.3)
+ def unsetOption(self, key):
+ """Un-sets the option given to the key for the underlying data source.
+ """
+ self._jwrite = self._jwrite.unsetOption(key)
return self
@since(1.4)
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 0cf702143c773..97d83efd439f6 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -349,7 +349,12 @@ def option(self, key, value):
>>> s = spark.readStream.option("x", 1)
"""
- self._jreader = self._jreader.option(key, to_str(value))
+ if isinstance(value, (list, tuple)):
+ gateway = self._spark._sc._gateway
+ jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
+ self._jreader = self._jreader.option(key, jvalues)
+ else:
+ self._jreader = self._jreader.option(key, to_str(value))
return self
@since(2.0)
@@ -366,7 +371,16 @@ def options(self, **options):
>>> s = spark.readStream.options(x="1", y=2)
"""
for k in options:
- self._jreader = self._jreader.option(k, to_str(options[k]))
+ self.option(k, options[k])
+ return self
+
+ @since(2.3)
+ def unsetOption(self, key):
+ """Un-sets the option given to the key for the underlying data source.
+
+ .. note:: Evolving.
+ """
+ self._jreader = self._jreader.unsetOption(key)
return self
@since(2.0)
@@ -579,7 +593,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
uses the default value, ``false``.
:param nullValue: sets the string representation of a null value. If None is set, it uses
the default value, empty string. Since 2.0.1, this ``nullValue`` param
- applies to all supported types including the string type.
+ applies to all supported types including the string type. A list or tuple
+ of strings can also be set for this option to represent multiple null values.
:param nanValue: sets the string representation of a non-number value. If None is set, it
uses the default value, ``NaN``.
:param positiveInf: sets the string representation of a positive infinity value. If None
@@ -710,7 +725,12 @@ def option(self, key, value):
.. note:: Evolving.
"""
- self._jwrite = self._jwrite.option(key, to_str(value))
+ if isinstance(value, (list, tuple)):
+ gateway = self._spark._sc._gateway
+ jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value)
+ self._jwrite = self._jwrite.option(key, jvalues)
+ else:
+ self._jwrite = self._jwrite.option(key, to_str(value))
return self
@since(2.0)
@@ -725,7 +745,16 @@ def options(self, **options):
.. note:: Evolving.
"""
for k in options:
- self._jwrite = self._jwrite.option(k, to_str(options[k]))
+ self.option(k, options[k])
+ return self
+
+ @since(2.3)
+ def unsetOption(self, key):
+ """Un-sets the option given to the key for the underlying data source.
+
+ .. note:: Evolving.
+ """
+ self._jwrite = self._jwrite.unsetOption(key)
return self
@since(2.0)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3d87ccfc03ddd..cd2d845e244cf 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1369,7 +1369,7 @@ def test_save_and_load(self):
self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
csvpath = os.path.join(tempfile.mkdtemp(), 'data')
- df.write.option('quote', None).format('csv').save(csvpath)
+ df.write.unsetOption('quote').format('csv').save(csvpath)
shutil.rmtree(tmpPath)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index f741dcfbf2002..ec15049f3b485 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -297,12 +297,21 @@ tablePropertyKey
;
tablePropertyValue
+ : propertyValue
+ | propertyArrayValue
+ ;
+
+propertyValue
: INTEGER_VALUE
| DECIMAL_VALUE
| booleanValue
| STRING
;
+propertyArrayValue
+ : '[' propertyValue (',' propertyValue)* ']'
+ ;
+
constantList
: '(' constant (',' constant)* ')'
;
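The `propertyArrayValue` rule above lets DDL statements pass an array of scalar values for a table property or data source option; the parser change further below serializes such an array into a JSON array string. A minimal sketch of the syntax this enables, modeled on the CSV DDL test added later in this patch (the table name and path are placeholders, not part of the patch):

    // Scala sketch: array-valued OPTIONS entry via the new grammar rule.
    spark.sql(
      """
        |CREATE TEMPORARY TABLE cars_with_nulls USING csv
        |OPTIONS (path 'cars-null.csv', header 'true', nullValue ['', 'null', 'N/A'])
      """.stripMargin)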
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 4f375e59c34d4..f463b832a0003 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, Properties}
import scala.collection.JavaConverters._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.api.java.JavaRDD
@@ -116,6 +119,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
def option(key: String, value: Double): DataFrameReader = option(key, value.toString)
+ /**
+ * (Scala-specific) Adds an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Seq[String]): DataFrameReader = {
+ option(key, compact(render(value)))
+ }
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Array[String]): DataFrameReader = option(key, value.toSeq)
+
/**
* (Scala-specific) Adds input options for the underlying data source.
*
@@ -148,6 +167,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
this
}
+ /**
+ * Un-sets an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def unsetOption(key: String): DataFrameReader = {
+ this.extraOptions.remove(key)
+ this
+ }
+
/**
* Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
* key-value stores).
@@ -502,7 +531,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
* whitespaces from values being read should be skipped.
* `nullValue` (default empty string): sets the string representation of a null value. Since
- * 2.0.1, this applies to all supported types including the string type.
+ * 2.0.1, this applies to all supported types including the string type.
+ * An array of strings can also be set for this option to represent multiple null values.
+ * For example, `spark.read.format("csv").option("nullValue", Seq("", "null"))`.
* `nanValue` (default `NaN`): sets the string representation of a non-number" value.
* `positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.
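A minimal usage sketch of the new reader API, assuming a placeholder input path; the Seq overload stores the value as a JSON array string, and `unsetOption` simply removes the key from `extraOptions`:

    // Scala sketch: multiple null tokens for CSV reads, plus un-setting an option.
    val df = spark.read
      .format("csv")
      .option("header", "true")
      .option("nullValue", Seq("", "null", "N/A"))  // stored as the JSON string ["","null","N/A"]
      .option("quote", "'")
      .unsetOption("quote")                         // reverts to the default quote character
      .load("people.csv")                           // placeholder path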
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 07347d2748544..2329245ed2518 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, Properties}
import scala.collection.JavaConverters._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
@@ -125,6 +128,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*/
def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
+ /**
+ * (Scala-specific) Adds an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Seq[String]): DataFrameWriter[T] = {
+ option(key, compact(render(value)))
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Array[String]): DataFrameWriter[T] = option(key, value.toSeq)
+
/**
* (Scala-specific) Adds output options for the underlying data source.
*
@@ -157,6 +176,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
this
}
+ /**
+ * Un-sets an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def unsetOption(key: String): DataFrameWriter[T] = {
+ this.extraOptions.remove(key)
+ this
+ }
+
/**
* Partitions the output by the given columns on the file system. If specified, the output is
* laid out on the file system similar to Hive's partitioning scheme. As an example, when we
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index d3f6ab5654689..25469c8bbd4d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.TerminalNode
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
@@ -587,6 +589,21 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* the property value based on whether its a string, integer, boolean or decimal literal.
*/
override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
+ if (value == null) {
+ null
+ } else if (value.propertyValue != null) {
+ visitPropertyValue(value.propertyValue)
+ } else if (value.propertyArrayValue != null) {
+ val values = value.propertyArrayValue.propertyValue.asScala.map { v =>
+ visitPropertyValue(v)
+ }
+ compact(render(values))
+ } else {
+ value.getText
+ }
+ }
+
+ override def visitPropertyValue(value: PropertyValueContext): String = {
if (value == null) {
null
} else if (value.STRING != null) {
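The array branch above converts each element to its scalar string form and then renders the whole sequence with json4s, so downstream consumers see a single JSON array string. A standalone sketch of that encoding step (not the parser itself):

    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    // What visitTablePropertyValue produces for a property like d = [1, 0.1, TRUE, "e"]:
    // each element is first visited as a scalar string, then the whole sequence is
    // rendered as a JSON array string.
    val elements = Seq("1", "0.1", "true", "e")
    val encoded = compact(render(elements))  // ["1","0.1","true","e"]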
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index b64d71bb4eef2..b611e3b631fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -79,7 +79,7 @@ private[csv] object CSVInferSchema {
* point checking if it is an Int, as the final type must be Double or higher.
*/
def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
- if (field == null || field.isEmpty || field == options.nullValue) {
+ if (field == null || field.isEmpty || options.nullValue.contains(field)) {
typeSoFar
} else {
typeSoFar match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index a13a5a34b4a84..1b5b59843cacd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.StandardCharsets
import java.util.{Locale, TimeZone}
+import scala.util.Try
+
import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling}
import org.apache.commons.lang3.time.FastDateFormat
+import org.json4s._
+import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
@@ -104,7 +108,17 @@ class CSVOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
- val nullValue = parameters.getOrElse("nullValue", "")
+ val nullValue: Array[String] = parameters.get("nullValue").map { str =>
+ Try {
+ implicit val formats = DefaultFormats
+ val arr = parse(str).extract[Array[String]]
+ // If the input is a plain string rather than a JSON array, extraction can yield an
+ // empty array. In that case, throw so that the raw string value is used as-is for
+ // backwards compatibility.
+ require(arr.nonEmpty)
+ arr
+ }.getOrElse(Array(str))
+ }.getOrElse(Array(""))
val nanValue = parameters.getOrElse("nanValue", "NaN")
@@ -151,8 +165,8 @@ class CSVOptions(
format.setComment(comment)
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
- writerSettings.setNullValue(nullValue)
- writerSettings.setEmptyValue(nullValue)
+ writerSettings.setNullValue(nullValue(0))
+ writerSettings.setEmptyValue(nullValue(0))
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -171,7 +185,7 @@ class CSVOptions(
settings.setReadInputOnSeparateThread(false)
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
- settings.setNullValue(nullValue)
+ settings.setNullValue(nullValue(0))
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings
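On the read side, `nullValue` may now be either a plain string (legacy behaviour) or a JSON array string produced by the new option overloads. A condensed sketch of the decode-with-fallback logic used above (the helper name is illustrative, not part of the patch):

    import scala.util.Try
    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    // Decode a nullValue that may be a plain string or a JSON array string.
    def parseNullValues(raw: String): Array[String] = Try {
      implicit val formats = DefaultFormats
      val arr = parse(raw).extract[Array[String]]
      require(arr.nonEmpty)  // a non-array input can extract to an empty array; fall back then
      arr
    }.getOrElse(Array(raw))

    parseNullValues("""["", "null"]""")  // Array("", "null")
    parseNullValues("\\N")               // Array("\\N"): legacy single-token behaviour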
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala
index 4082a0df8ba75..dbdaae1162109 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala
@@ -65,7 +65,7 @@ private[csv] class UnivocityGenerator(
if (!row.isNullAt(i)) {
values(i) = valueConverters(i).apply(row, i)
} else {
- values(i) = options.nullValue
+ values(i) = options.nullValue(0)
}
i += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 0e41f3c7aa6b8..f9aaea50108bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -170,7 +170,7 @@ class UnivocityParser(
name: String,
nullable: Boolean,
options: CSVOptions)(converter: ValueConverter): Any = {
- if (datum == options.nullValue || datum == null) {
+ if (options.nullValue.contains(datum) || datum == null) {
if (!nullable) {
throw new RuntimeException(s"null value found but field $name is not nullable.")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 6057a795c8bf5..7fc475d69b5b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -71,7 +71,8 @@ trait RelationProvider {
* Returns a new base relation with the given parameters.
*
* @note The parameters' keywords are case insensitive and this insensitivity is enforced
- * by the Map that is passed to the function.
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
*/
def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}
@@ -102,7 +103,8 @@ trait SchemaRelationProvider {
* Returns a new base relation with the given parameters and user defined schema.
*
* @note The parameters' keywords are case insensitive and this insensitivity is enforced
- * by the Map that is passed to the function.
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
*/
def createRelation(
sqlContext: SQLContext,
@@ -122,7 +124,12 @@ trait StreamSourceProvider {
/**
* Returns the name and schema of the source that can be used to continually read data.
+ *
* @since 2.0.0
+ *
+ * @note The parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
*/
def sourceSchema(
sqlContext: SQLContext,
@@ -131,7 +138,13 @@ trait StreamSourceProvider {
parameters: Map[String, String]): (String, StructType)
/**
+ * Returns a source that can be used to continually read data.
+ *
* @since 2.0.0
+ *
+ * @note The parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
*/
def createSource(
sqlContext: SQLContext,
@@ -150,6 +163,13 @@ trait StreamSourceProvider {
@Experimental
@InterfaceStability.Unstable
trait StreamSinkProvider {
+ /**
+ * Returns a sink that can be used to continually write data.
+ *
+ * @note The parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
+ */
def createSink(
sqlContext: SQLContext,
parameters: Map[String, String],
@@ -172,6 +192,10 @@ trait CreatableRelationProvider {
* @return Relation with a known schema
*
* @since 1.3.0
+ *
+ * @note The parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function. Also, the value of the Map can be a JSON
+ * array string if users set an array via option APIs.
*/
def createRelation(
sqlContext: SQLContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index a42e28053a96a..42d54cab507e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -21,6 +21,9 @@ import java.util.Locale
import scala.collection.JavaConverters._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
@@ -108,6 +111,22 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*/
def option(key: String, value: Double): DataStreamReader = option(key, value.toString)
+ /**
+ * (Scala-specific) Adds an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Seq[String]): DataStreamReader = {
+ option(key, compact(render(value)))
+ }
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Array[String]): DataStreamReader = option(key, value.toSeq)
+
/**
* (Scala-specific) Adds input options for the underlying data source.
*
@@ -140,6 +159,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
this
}
+ /**
+ * Un-sets an input option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def unsetOption(key: String): DataStreamReader = {
+ this.extraOptions.remove(key)
+ this
+ }
+
/**
* Loads input data stream in as a `DataFrame`, for data streams that don't require a path
@@ -259,7 +287,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
* whitespaces from values being read should be skipped.
* `nullValue` (default empty string): sets the string representation of a null value. Since
- * 2.0.1, this applies to all supported types including the string type.
+ * 2.0.1, this applies to all supported types including the string type.
+ * An array of strings can also be set for this option to represent multiple null values.
+ * For example, `spark.readStream.format("csv").option("nullValue", Seq("", "null"))`.
* `nanValue` (default `NaN`): sets the string representation of a non-number" value.
* `positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.
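The streaming reader gets the same Seq/Array overloads and `unsetOption`. A brief sketch, assuming a placeholder schema and input directory (file streams need an explicit schema unless schema inference is enabled):

    import org.apache.spark.sql.types._

    // Scala sketch: streaming counterpart of the batch example above.
    val schema = new StructType().add("name", StringType).add("age", IntegerType)
    val stream = spark.readStream
      .format("csv")
      .schema(schema)
      .option("nullValue", Array("", "null"))  // encoded as the JSON string ["","null"]
      .load("/tmp/csv-input")                  // placeholder directory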
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 14e7df672cc58..dbc07e6b51011 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -21,6 +21,9 @@ import java.util.Locale
import scala.collection.JavaConverters._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
@@ -179,6 +182,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
*/
def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString)
+
+ /**
+ * (Scala-specific) Adds an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Seq[String]): DataStreamWriter[T] = {
+ option(key, compact(render(value)))
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def option(key: String, value: Array[String]): DataStreamWriter[T] = option(key, value.toSeq)
+
/**
* (Scala-specific) Adds output options for the underlying data source.
*
@@ -211,6 +231,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
this
}
+ /**
+ * Un-sets an output option for the underlying data source.
+ *
+ * @since 2.3.0
+ */
+ def unsetOption(key: String): DataStreamWriter[T] = {
+ this.extraOptions.remove(key)
+ this
+ }
+
/**
* Starts the execution of the streaming query, which will continually output results to the given
* path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
index 7babf7573c075..0b628ea124b33 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
@@ -70,6 +70,7 @@ public void testOptionsAPI() {
.option("b", 1)
.option("c", 1.0)
.option("d", true)
+ .option("e", new String[]{"1", "2"})
.options(map)
.text()
.write()
@@ -77,6 +78,7 @@ public void testOptionsAPI() {
.option("b", 1)
.option("c", 1.0)
.option("d", true)
+ .option("e", new String[]{"1", "2"})
.options(map)
.format("org.apache.spark.sql.test")
.save();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 4ee38215f5973..d58849fc3bec2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -22,6 +22,9 @@ import java.util.Locale
import scala.reflect.{classTag, ClassTag}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -1030,15 +1033,16 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
"""
|CREATE DATABASE database_name
|LOCATION '/home/user/db'
- |WITH DBPROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE)
+ |WITH DBPROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE, 'd'=[1, 0.1, TRUE, "e"])
""".stripMargin
val parsed = parser.parsePlan(sql)
+ val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e")))
val expected = CreateDatabaseCommand(
"database_name",
ifNotExists = false,
Some("/home/user/db"),
None,
- Map("a" -> "1", "b" -> "0.1", "c" -> "true"))
+ Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue))
comparePlans(parsed, expected)
}
@@ -1047,12 +1051,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
val sql =
"""
|ALTER TABLE table_name
- |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE)
+ |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE, 'd'=[1,0.1,TRUE,"e"])
""".stripMargin
val parsed = parser.parsePlan(sql)
+ val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e")))
val expected = AlterTableSetPropertiesCommand(
TableIdentifier("table_name"),
- Map("a" -> "1", "b" -> "0.1", "c" -> "true"),
+ Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue),
isView = false)
comparePlans(parsed, expected)
@@ -1062,14 +1067,15 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
val sql =
"""
|CREATE TABLE table_name USING json
- |OPTIONS (a 1, b 0.1, c TRUE)
+ |OPTIONS (a 1, b 0.1, c TRUE, d [1, 0.1,TRUE, "e"])
""".stripMargin
+ val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e")))
val expectedTableDesc = CatalogTable(
identifier = TableIdentifier("table_name"),
tableType = CatalogTableType.MANAGED,
storage = CatalogStorageFormat.empty.copy(
- properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true")
+ properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue)
),
schema = new StructType,
provider = Some("json")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 661742087112f..386aae86bb323 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.execution.datasources.csv
+import scala.util.Random
+
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
@@ -92,17 +97,23 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}
test("Null fields are handled properly when a nullValue is specified") {
- var options = new CSVOptions(Map("nullValue" -> "null"), "GMT")
- assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
- assert(CSVInferSchema.inferField(StringType, "null", options) == StringType)
- assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
-
- options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT")
- assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
- assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
- assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
- assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
- assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
+ val types = Seq(NullType, StringType, LongType, IntegerType,
+ DoubleType, TimestampType, BooleanType, DecimalType(1, 1))
+ types.foreach { t =>
+ Seq("null", "\\N").foreach { v =>
+ val options = new CSVOptions(Map("nullValue" -> v), "GMT")
+ assert(CSVInferSchema.inferField(t, v, options) == t)
+ }
+ }
+
+ // Nullable field with a nullValue option that contains multiple values.
+ val nullValues = Seq("abc", "", "123", "null")
+ val nullValuesStr = compact(render(nullValues))
+ types.foreach { t =>
+ val options = new CSVOptions(Map("nullValue" -> nullValuesStr), "GMT")
+ val nullVal = nullValues(Random.nextInt(nullValues.length))
+ assert(CSVInferSchema.inferField(t, nullVal, options) == t)
+ }
}
test("Merging Nulltypes should yield Nulltype.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 243a55cffd47f..0f680a0ce0128 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -596,6 +596,35 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
}
+ test("multiple nullValue option for reading") {
+ val cars = spark.read
+ .format("csv")
+ .option("header", "true")
+ .option("nullValue", Seq("2012", "Tesla", "null"))
+ .load(testFile(carsNullFile))
+
+ verifyCars(cars, withHeader = true, checkValues = false)
+ val results = cars.collect()
+ assert(results(0).toSeq === Array(null, null, "S", null, null))
+ assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
+ }
+
+ test("DDL multiple nullValue option for reading") {
+ spark.sql(
+ s"""
+ |CREATE TEMPORARY TABLE carsTable USING csv
+ |OPTIONS (path "${testFile(carsNullFile)}", header "true",
+ |nullValue [2012, 'Tesla', 'null'])
+ """.stripMargin.replaceAll("\n", " "))
+
+ val cars = spark.sql("SELECT * FROM carsTable")
+
+ verifyCars(cars, withHeader = true, checkValues = false)
+ val results = cars.collect()
+ assert(results(0).toSeq === Array(null, null, "S", null, null))
+ assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
+ }
+
test("save csv with compression codec option") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
index efbf73534bd19..37505508ea9c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
@@ -20,6 +20,11 @@ package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
import java.util.Locale
+import scala.util.Random
+
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -78,6 +83,16 @@ class UnivocityParserSuite extends SparkFunSuite {
assert(message.contains("null value found but field _1 is not nullable."))
}
+ // Nullable field with a nullValue option that contains multiple values.
+ val nullValues = Seq("abc", "", "123", "null")
+ val nullValuesStr = compact(render(nullValues))
+ types.foreach { t =>
+ val options = new CSVOptions(Map("nullValue" -> nullValuesStr), "GMT")
+ val converter =
+ parser.makeConverter("_1", t, nullable = true, options = options)
+ assertNull(converter.apply(nullValues(Random.nextInt(nullValues.length))))
+ }
+
// If nullValue is different with empty string, then, empty string should not be casted into
// null.
Seq(true, false).foreach { b =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index aa163d2211c38..d41fd5b63f0e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -24,6 +24,8 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import org.apache.hadoop.fs.Path
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
@@ -162,11 +164,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
.option("opt1", "1")
.options(Map("opt2" -> "2"))
.options(map)
+ .option("opt4", "4")
+ .unsetOption("opt4")
.load()
assert(LastOptions.parameters("opt1") == "1")
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")
+ assert(!LastOptions.parameters.contains("opt4"))
LastOptions.clear()
@@ -176,12 +181,37 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
.options(Map("opt2" -> "2"))
.options(map)
.option("checkpointLocation", newMetadataDir)
+ .option("opt4", "4")
+ .unsetOption("opt4")
.start()
.stop()
assert(LastOptions.parameters("opt1") == "1")
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")
+ assert(!LastOptions.parameters.contains("opt4"))
+ }
+
+ test("options - array") {
+ val expected = compact(render(Seq("1", "0.1", "TRUE", "e")))
+
+ val df = spark.readStream
+ .format("org.apache.spark.sql.streaming.test")
+ .option("opt1", Seq("1", "0.1", "TRUE", "e"))
+ .load()
+
+ assert(LastOptions.parameters("opt1") == expected)
+
+ LastOptions.clear()
+
+ df.writeStream
+ .format("org.apache.spark.sql.streaming.test")
+ .option("opt1", Seq("1", "0.1", "TRUE", "e"))
+ .option("checkpointLocation", newMetadataDir)
+ .start()
+ .stop()
+
+ assert(LastOptions.parameters("opt1") == expected)
}
test("partitioning") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 569bac156b531..83d51cf45c40c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -21,6 +21,8 @@ import java.io.File
import java.util.Locale
import java.util.concurrent.ConcurrentLinkedQueue
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
import org.scalatest.BeforeAndAfter
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
@@ -188,11 +190,14 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
.option("opt1", "1")
.options(Map("opt2" -> "2"))
.options(map)
+ .option("opt4", "4")
+ .unsetOption("opt4")
.load()
assert(LastOptions.parameters("opt1") == "1")
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")
+ assert(!LastOptions.parameters.contains("opt4"))
LastOptions.clear()
@@ -201,11 +206,34 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
.option("opt1", "1")
.options(Map("opt2" -> "2"))
.options(map)
+ .option("opt4", "4")
+ .unsetOption("opt4")
.save()
assert(LastOptions.parameters("opt1") == "1")
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")
+ assert(!LastOptions.parameters.contains("opt4"))
+ }
+
+ test("options - array") {
+ val expected = compact(render(Seq("1", "0.1", "TRUE", "e")))
+
+ val df = spark.read
+ .format("org.apache.spark.sql.test")
+ .option("opt1", Seq("1", "0.1", "TRUE", "e"))
+ .load()
+
+ assert(LastOptions.parameters("opt1") == expected)
+
+ LastOptions.clear()
+
+ df.write
+ .format("org.apache.spark.sql.test")
+ .option("opt1", Seq("1", "0.1", "TRUE", "e"))
+ .save()
+
+ assert(LastOptions.parameters("opt1") == expected)
}
test("save mode") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 2ec593b95c9b6..9a800765e7198 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -65,11 +65,13 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
test("test hadoop conf option propagation") {
withTempPath { file =>
+ val nullVal: String = null
+
// Test write side
val df = spark.range(10).selectExpr("cast(id as string)")
df.write
.option("some-random-write-option", "hahah-WRITE")
- .option("some-null-value-option", null) // test null robustness
+ .option("some-null-value-option", nullVal) // test null robustness
.option("dataSchema", df.schema.json)
.format(dataSourceName).save(file.getAbsolutePath)
assert(SimpleTextRelation.lastHadoopConf.get.get("some-random-write-option") == "hahah-WRITE")
@@ -77,7 +79,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
// Test read side
val df1 = spark.read
.option("some-random-read-option", "hahah-READ")
- .option("some-null-value-option", null) // test null robustness
+ .option("some-null-value-option", nullVal) // test null robustness
.option("dataSchema", df.schema.json)
.format(dataSourceName)
.load(file.getAbsolutePath)