diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cb847a0420311..887930e9a44c4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -121,7 +121,12 @@ def option(self, key, value): in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ - self._jreader = self._jreader.option(key, to_str(value)) + if isinstance(value, (list, tuple)): + gateway = self._spark._sc._gateway + jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value) + self._jreader = self._jreader.option(key, jvalues) + else: + self._jreader = self._jreader.option(key, to_str(value)) return self @since(1.4) @@ -134,7 +139,14 @@ def options(self, **options): If it isn't set, it uses the default value, session local timezone. """ for k in options: - self._jreader = self._jreader.option(k, to_str(options[k])) + self.option(k, options[k]) + return self + + @since(2.3) + def unsetOption(self, key): + """Un-sets the option given to the key for the underlying data source. + """ + self._jreader = self._jreader.unsetOption(key) return self @since(1.4) @@ -322,6 +334,7 @@ def text(self, paths): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) + @ignore_unicode_prefix @since(2.0) def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, @@ -362,7 +375,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param - applies to all supported types including the string type. + applies to all supported types including the string type. A list or tuple + of strings to represent null values can be set for this option. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None @@ -408,6 +422,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] + + >>> df = spark.read.csv('python/test_support/sql/ages.csv', nullValue=['Tom', 'Joe']) + >>> df.collect() + [Row(_c0=None, _c1=u'20'), Row(_c0=None, _c1=u'30'), Row(_c0=u'Hyukjin', _c1=u'25')] """ self._set_opts( schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment, @@ -544,7 +562,12 @@ def option(self, key, value): timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ - self._jwrite = self._jwrite.option(key, to_str(value)) + if isinstance(value, (list, tuple)): + gateway = self._spark._sc._gateway + jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value) + self._jwrite = self._jwrite.option(key, jvalues) + else: + self._jwrite = self._jwrite.option(key, to_str(value)) return self @since(1.4) @@ -557,7 +580,14 @@ def options(self, **options): If it isn't set, it uses the default value, session local timezone. 
""" for k in options: - self._jwrite = self._jwrite.option(k, to_str(options[k])) + self.option(k, options[k]) + return self + + @since(2.3) + def unsetOption(self, key): + """Un-sets the option given to the key for the underlying data source. + """ + self._jwrite = self._jwrite.unsetOption(key) return self @since(1.4) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 0cf702143c773..97d83efd439f6 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -349,7 +349,12 @@ def option(self, key, value): >>> s = spark.readStream.option("x", 1) """ - self._jreader = self._jreader.option(key, to_str(value)) + if isinstance(value, (list, tuple)): + gateway = self._spark._sc._gateway + jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value) + self._jreader = self._jreader.option(key, jvalues) + else: + self._jreader = self._jreader.option(key, to_str(value)) return self @since(2.0) @@ -366,7 +371,16 @@ def options(self, **options): >>> s = spark.readStream.options(x="1", y=2) """ for k in options: - self._jreader = self._jreader.option(k, to_str(options[k])) + self.option(k, options[k]) + return self + + @since(2.3) + def unsetOption(self, key): + """Un-sets the option given to the key for the underlying data source. + + .. note:: Evolving. + """ + self._jreader = self._jreader.unsetOption(key) return self @since(2.0) @@ -579,7 +593,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param - applies to all supported types including the string type. + applies to all supported types including the string type. A list or tuple + of strings to represent null values can be set for this option. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None @@ -710,7 +725,12 @@ def option(self, key, value): .. note:: Evolving. """ - self._jwrite = self._jwrite.option(key, to_str(value)) + if isinstance(value, (list, tuple)): + gateway = self._spark._sc._gateway + jvalues = utils.toJArray(gateway, gateway.jvm.java.lang.String, value) + self._jwrite = self._jwrite.option(key, jvalues) + else: + self._jwrite = self._jwrite.option(key, to_str(value)) return self @since(2.0) @@ -725,7 +745,16 @@ def options(self, **options): .. note:: Evolving. """ for k in options: - self._jwrite = self._jwrite.option(k, to_str(options[k])) + self.option(k, options[k]) + return self + + @since(2.3) + def unsetOption(self, key): + """Un-sets the option given to the key for the underlying data source. + + .. note:: Evolving. 
+ """ + self._jwrite = self._jwrite.unsetOption(key) return self @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3d87ccfc03ddd..cd2d845e244cf 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1369,7 +1369,7 @@ def test_save_and_load(self): self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) csvpath = os.path.join(tempfile.mkdtemp(), 'data') - df.write.option('quote', None).format('csv').save(csvpath) + df.write.unsetOption('quote').format('csv').save(csvpath) shutil.rmtree(tmpPath) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f741dcfbf2002..ec15049f3b485 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -297,12 +297,21 @@ tablePropertyKey ; tablePropertyValue + : propertyValue + | propertyArrayValue + ; + +propertyValue : INTEGER_VALUE | DECIMAL_VALUE | booleanValue | STRING ; +propertyArrayValue + : '[' propertyValue (',' propertyValue)* ']' + ; + constantList : '(' constant (',' constant)* ')' ; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4f375e59c34d4..f463b832a0003 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,9 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -116,6 +119,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ def option(key: String, value: Double): DataFrameReader = option(key, value.toString) + /** + * (Scala-specific) Adds an input option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Seq[String]): DataFrameReader = { + option(key, compact(render(value))) + } + + /** + * Adds an input option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Array[String]): DataFrameReader = option(key, value.toSeq) + /** * (Scala-specific) Adds input options for the underlying data source. * @@ -148,6 +167,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { this } + /** + * Un-sets an input option for the underlying data source. + * + * @since 2.3.0 + */ + def unsetOption(key: String): DataFrameReader = { + this.extraOptions.remove(key) + this + } + /** * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external * key-value stores). @@ -502,7 +531,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    * <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
    * whitespaces from values being read should be skipped.</li>
    * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
-   * 2.0.1, this applies to all supported types including the string type.</li>
+   * 2.0.1, this applies to all supported types including the string type.
+   * An array of strings to represent null values can be set for this option. For example,
+   * `spark.read.format("csv").option("nullValue", Seq("", "null"))`.</li>
    * <li>`nanValue` (default `NaN`): sets the string representation of a non-number" value.</li>
    * <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
    * value.</li>
  • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..2329245ed2518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,6 +21,9 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} @@ -125,6 +128,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { */ def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString) + /** + * (Scala-specific) Adds an output option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Seq[String]): DataFrameWriter[T] = { + option(key, compact(render(value))) + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Array[String]): DataFrameWriter[T] = option(key, value.toSeq) + /** * (Scala-specific) Adds output options for the underlying data source. * @@ -157,6 +176,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * Un-sets an output option for the underlying data source. + * + * @since 2.3.0 + */ + def unsetOption(key: String): DataFrameWriter[T] = { + this.extraOptions.remove(key) + this + } + /** * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. As an example, when we diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d3f6ab5654689..25469c8bbd4d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -587,6 +589,21 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * the property value based on whether its a string, integer, boolean or decimal literal. 
*/ override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.propertyValue != null) { + visitPropertyValue(value.propertyValue) + } else if (value.propertyArrayValue != null) { + val values = value.propertyArrayValue.propertyValue.asScala.map { v => + visitPropertyValue(v) + } + compact(render(values)) + } else { + value.getText + } + } + + override def visitPropertyValue(value: PropertyValueContext): String = { if (value == null) { null } else if (value.STRING != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index b64d71bb4eef2..b611e3b631fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -79,7 +79,7 @@ private[csv] object CSVInferSchema { * point checking if it is an Int, as the final type must be Double or higher. */ def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { - if (field == null || field.isEmpty || field == options.nullValue) { + if (field == null || field.isEmpty || options.nullValue.contains(field)) { typeSoFar } else { typeSoFar match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index a13a5a34b4a84..1b5b59843cacd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} +import scala.util.Try + import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} import org.apache.commons.lang3.time.FastDateFormat +import org.json4s._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ @@ -104,7 +108,17 @@ class CSVOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val nullValue = parameters.getOrElse("nullValue", "") + val nullValue: Array[String] = parameters.get("nullValue").map { str => + Try { + implicit val formats = DefaultFormats + val arr = parse(str).extract[Array[String]] + // If the input is just a string not formatted as a json array, it can be an empty array. + // In this case throws an exception so that the string value is used as is for backwards + // compatibility. 
+ require(arr.nonEmpty) + arr + }.getOrElse(Array(str)) + }.getOrElse(Array("")) val nanValue = parameters.getOrElse("nanValue", "NaN") @@ -151,8 +165,8 @@ class CSVOptions( format.setComment(comment) writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) - writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setNullValue(nullValue(0)) + writerSettings.setEmptyValue(nullValue(0)) writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -171,7 +185,7 @@ class CSVOptions( settings.setReadInputOnSeparateThread(false) settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) - settings.setNullValue(nullValue) + settings.setNullValue(nullValue(0)) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala index 4082a0df8ba75..dbdaae1162109 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala @@ -65,7 +65,7 @@ private[csv] class UnivocityGenerator( if (!row.isNullAt(i)) { values(i) = valueConverters(i).apply(row, i) } else { - values(i) = options.nullValue + values(i) = options.nullValue(0) } i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 0e41f3c7aa6b8..f9aaea50108bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -170,7 +170,7 @@ class UnivocityParser( name: String, nullable: Boolean, options: CSVOptions)(converter: ValueConverter): Any = { - if (datum == options.nullValue || datum == null) { + if (options.nullValue.contains(datum) || datum == null) { if (!nullable) { throw new RuntimeException(s"null value found but field $name is not nullable.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6057a795c8bf5..7fc475d69b5b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -71,7 +71,8 @@ trait RelationProvider { * Returns a new base relation with the given parameters. * * @note The parameters' keywords are case insensitive and this insensitivity is enforced - * by the Map that is passed to the function. + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. */ def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation } @@ -102,7 +103,8 @@ trait SchemaRelationProvider { * Returns a new base relation with the given parameters and user defined schema. 
* * @note The parameters' keywords are case insensitive and this insensitivity is enforced - * by the Map that is passed to the function. + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. */ def createRelation( sqlContext: SQLContext, @@ -122,7 +124,12 @@ trait StreamSourceProvider { /** * Returns the name and schema of the source that can be used to continually read data. + * * @since 2.0.0 + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. */ def sourceSchema( sqlContext: SQLContext, @@ -131,7 +138,13 @@ trait StreamSourceProvider { parameters: Map[String, String]): (String, StructType) /** + * Returns a source that can be used to continually read data. + * * @since 2.0.0 + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. */ def createSource( sqlContext: SQLContext, @@ -150,6 +163,13 @@ trait StreamSourceProvider { @Experimental @InterfaceStability.Unstable trait StreamSinkProvider { + /** + * Returns a sink that can be used to continually write data. + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. + */ def createSink( sqlContext: SQLContext, parameters: Map[String, String], @@ -172,6 +192,10 @@ trait CreatableRelationProvider { * @return Relation with a known schema * * @since 1.3.0 + * + * @note The parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. Also, the value of the Map can be a JSON + * array string if users set an array via option APIs. */ def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index a42e28053a96a..42d54cab507e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,6 +21,9 @@ import java.util.Locale import scala.collection.JavaConverters._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} @@ -108,6 +111,22 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo */ def option(key: String, value: Double): DataStreamReader = option(key, value.toString) + /** + * (Scala-specific) Adds an input option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Seq[String]): DataStreamReader = { + option(key, compact(render(value))) + } + + /** + * Adds an input option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Array[String]): DataStreamReader = option(key, value.toSeq) + /** * (Scala-specific) Adds input options for the underlying data source. 
* @@ -140,6 +159,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } + /** + * Un-sets an input option for the underlying data source. + * + * @since 2.3.0 + */ + def unsetOption(key: String): DataStreamReader = { + this.extraOptions.remove(key) + this + } /** * Loads input data stream in as a `DataFrame`, for data streams that don't require a path @@ -259,7 +287,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    * <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
    * whitespaces from values being read should be skipped.</li>
    * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
-   * 2.0.1, this applies to all supported types including the string type.</li>
+   * 2.0.1, this applies to all supported types including the string type.
+   * An array of strings to represent null values can be set for this option. For example,
+   * `spark.readStream.format("csv").option("nullValue", Seq("", "null"))`.</li>
    * <li>`nanValue` (default `NaN`): sets the string representation of a non-number" value.</li>
    * <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
    * value.</li>
  • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 14e7df672cc58..dbc07e6b51011 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,6 +21,9 @@ import java.util.Locale import scala.collection.JavaConverters._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -179,6 +182,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { */ def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) + + /** + * (Scala-specific) Adds an output option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Seq[String]): DataStreamWriter[T] = { + option(key, compact(render(value))) + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.3.0 + */ + def option(key: String, value: Array[String]): DataStreamWriter[T] = option(key, value.toSeq) + /** * (Scala-specific) Adds output options for the underlying data source. * @@ -211,6 +231,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * Un-sets an option option for the underlying data source. + * + * @since 2.3.0 + */ + def unsetOption(key: String): DataStreamWriter[T] = { + this.extraOptions.remove(key) + this + } + /** * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. 
The returned [[StreamingQuery]] object can be used to interact with diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java index 7babf7573c075..0b628ea124b33 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java @@ -70,6 +70,7 @@ public void testOptionsAPI() { .option("b", 1) .option("c", 1.0) .option("d", true) + .option("e", new String[]{"1", "2"}) .options(map) .text() .write() @@ -77,6 +78,7 @@ public void testOptionsAPI() { .option("b", 1) .option("c", 1.0) .option("d", true) + .option("e", new String[]{"1", "2"}) .options(map) .format("org.apache.spark.sql.test") .save(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 4ee38215f5973..d58849fc3bec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -22,6 +22,9 @@ import java.util.Locale import scala.reflect.{classTag, ClassTag} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -1030,15 +1033,16 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { """ |CREATE DATABASE database_name |LOCATION '/home/user/db' - |WITH DBPROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE) + |WITH DBPROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE, 'd'=[1, 0.1, TRUE, "e"]) """.stripMargin val parsed = parser.parsePlan(sql) + val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e"))) val expected = CreateDatabaseCommand( "database_name", ifNotExists = false, Some("/home/user/db"), None, - Map("a" -> "1", "b" -> "0.1", "c" -> "true")) + Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue)) comparePlans(parsed, expected) } @@ -1047,12 +1051,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { val sql = """ |ALTER TABLE table_name - |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE) + |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE, 'd'=[1,0.1,TRUE,"e"]) """.stripMargin val parsed = parser.parsePlan(sql) + val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e"))) val expected = AlterTableSetPropertiesCommand( TableIdentifier("table_name"), - Map("a" -> "1", "b" -> "0.1", "c" -> "true"), + Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue), isView = false) comparePlans(parsed, expected) @@ -1062,14 +1067,15 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { val sql = """ |CREATE TABLE table_name USING json - |OPTIONS (a 1, b 0.1, c TRUE) + |OPTIONS (a 1, b 0.1, c TRUE, d [1, 0.1,TRUE, "e"]) """.stripMargin + val dValue = compact(render(Seq(1.toString, 0.1.toString, true.toString, "e"))) val expectedTableDesc = CatalogTable( identifier = TableIdentifier("table_name"), tableType = CatalogTableType.MANAGED, storage = CatalogStorageFormat.empty.copy( - properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true") + properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true", "d" -> dValue) ), schema = new StructType, provider = Some("json") diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..386aae86bb323 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.execution.datasources.csv +import scala.util.Random + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -92,17 +97,23 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") - assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) - assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") - assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) - assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) - assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) - assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) - assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) + val types = Seq(NullType, StringType, LongType, IntegerType, + DoubleType, TimestampType, BooleanType, DecimalType(1, 1)) + types.foreach { t => + Seq("null", "\\N").foreach { v => + val options = new CSVOptions(Map("nullValue" -> v), "GMT") + assert(CSVInferSchema.inferField(t, v, options) == t) + } + } + + // nullable field with multiple nullValue option. 
+ val nullValues = Seq("abc", "", "123", "null") + val nullValuesStr = compact(render(nullValues)) + types.foreach { t => + val options = new CSVOptions(Map("nullValue" -> nullValuesStr), "GMT") + val nullVal = nullValues(Random.nextInt(nullValues.length)) + assert(CSVInferSchema.inferField(t, nullVal, options) == t) + } } test("Merging Nulltypes should yield Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 243a55cffd47f..0f680a0ce0128 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -596,6 +596,35 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } + test("multiple nullValue option for reading") { + val cars = spark.read + .format("csv") + .option("header", "true") + .option("nullValue", Seq("2012", "Tesla", "null")) + .load(testFile(carsNullFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(null, null, "S", null, null)) + assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) + } + + test("DDL multiple nullValue option for reading") { + spark.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsNullFile)}", header "true", + |nullValue [2012, 'Tesla', 'null']) + """.stripMargin.replaceAll("\n", " ")) + + val cars = spark.sql("SELECT * FROM carsTable") + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(null, null, "S", null, null)) + assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) + } + test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..37505508ea9c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -20,6 +20,11 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal import java.util.Locale +import scala.util.Random + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -78,6 +83,16 @@ class UnivocityParserSuite extends SparkFunSuite { assert(message.contains("null value found but field _1 is not nullable.")) } + // nullable field with multiple nullValue option. 
+ val nullValues = Seq("abc", "", "123", "null") + val nullValuesStr = compact(render(nullValues)) + types.foreach { t => + val options = new CSVOptions(Map("nullValue" -> nullValuesStr), "GMT") + val converter = + parser.makeConverter("_1", t, nullable = true, options = options) + assertNull(converter.apply(nullValues(Random.nextInt(nullValues.length)))) + } + // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..d41fd5b63f0e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -24,6 +24,8 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.hadoop.fs.Path +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter @@ -162,11 +164,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) + .option("opt4", "4") + .unsetOption("opt4") .load() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") + assert(!LastOptions.parameters.contains("opt4")) LastOptions.clear() @@ -176,12 +181,37 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .options(Map("opt2" -> "2")) .options(map) .option("checkpointLocation", newMetadataDir) + .option("opt4", "4") + .unsetOption("opt4") .start() .stop() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") + assert(!LastOptions.parameters.contains("opt4")) + } + + test("options - array") { + val expected = compact(render(Seq("1", "0.1", "TRUE", "e"))) + + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .option("opt1", Seq("1", "0.1", "TRUE", "e")) + .load() + + assert(LastOptions.parameters("opt1") == expected) + + LastOptions.clear() + + df.writeStream + .format("org.apache.spark.sql.streaming.test") + .option("opt1", Seq("1", "0.1", "TRUE", "e")) + .option("checkpointLocation", newMetadataDir) + .start() + .stop() + + assert(LastOptions.parameters("opt1") == expected) } test("partitioning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 569bac156b531..83d51cf45c40c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -21,6 +21,8 @@ import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.scalatest.BeforeAndAfter import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage @@ -188,11 +190,14 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .option("opt1", "1") .options(Map("opt2" 
-> "2")) .options(map) + .option("opt4", "4") + .unsetOption("opt4") .load() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") + assert(!LastOptions.parameters.contains("opt4")) LastOptions.clear() @@ -201,11 +206,34 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) + .option("opt4", "4") + .unsetOption("opt4") .save() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") + assert(!LastOptions.parameters.contains("opt4")) + } + + test("options - array") { + val expected = compact(render(Seq("1", "0.1", "TRUE", "e"))) + + val df = spark.read + .format("org.apache.spark.sql.test") + .option("opt1", Seq("1", "0.1", "TRUE", "e")) + .load() + + assert(LastOptions.parameters("opt1") == expected) + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.test") + .option("opt1", Seq("1", "0.1", "TRUE", "e")) + .save() + + assert(LastOptions.parameters("opt1") == expected) } test("save mode") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 2ec593b95c9b6..9a800765e7198 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -65,11 +65,13 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat test("test hadoop conf option propagation") { withTempPath { file => + val nullVal: String = null + // Test write side val df = spark.range(10).selectExpr("cast(id as string)") df.write .option("some-random-write-option", "hahah-WRITE") - .option("some-null-value-option", null) // test null robustness + .option("some-null-value-option", nullVal) // test null robustness .option("dataSchema", df.schema.json) .format(dataSourceName).save(file.getAbsolutePath) assert(SimpleTextRelation.lastHadoopConf.get.get("some-random-write-option") == "hahah-WRITE") @@ -77,7 +79,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat // Test read side val df1 = spark.read .option("some-random-read-option", "hahah-READ") - .option("some-null-value-option", null) // test null robustness + .option("some-null-value-option", nullVal) // test null robustness .option("dataSchema", df.schema.json) .format(dataSourceName) .load(file.getAbsolutePath)