diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 96ff389faf4a0..c5122843ff202 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -274,6 +274,7 @@ exportMethods("%<=>%", "floor", "format_number", "format_string", + "from_csv", "from_json", "from_unixtime", "from_utc_timestamp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 6a8fef5aa7b22..d2ca1d6c00bb4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -188,6 +188,7 @@ NULL #' \item \code{to_json}: it is the column containing the struct, array of the structs, #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. +#' \item \code{from_csv}: it is the column containing the CSV string. #' } #' @param y Column to compute on. #' @param value A value to compute on. @@ -196,6 +197,13 @@ NULL #' \item \code{array_position}: a value to locate in the given array. #' \item \code{array_remove}: a value to remove in the given array. #' } +#' @param schema +#' \itemize{ +#' \item \code{from_json}: a structType object to use as the schema to use +#' when parsing the JSON string. Since Spark 2.3, the DDL-formatted string is +#' also supported for the schema. +#' \item \code{from_csv}: a DDL-formatted string +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. Additionally \code{to_json} supports the "pretty" @@ -2165,8 +2173,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' #' @rdname column_collection_functions -#' @param schema a structType object to use as the schema to use when parsing the JSON string. -#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method #' @examples @@ -2203,6 +2209,36 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") column(jc) }) +#' @details +#' \code{from_csv}: Parses a column containing a CSV string into a Column of \code{structType} +#' with the specified \code{schema}. +#' If the string is unparseable, the Column will contain the value NA. +#' +#' @rdname column_collection_functions +#' @aliases from_csv from_csv,Column,character-method +#' @examples +#' +#' \dontrun{ +#' df <- sql("SELECT 'Amsterdam,2018' as csv") +#' schema <- "city STRING, year INT" +#' head(select(df, from_csv(df$csv, schema)))} +#' @note from_csv since 3.0.0 +setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"), + function(x, schema, ...) { + if (class(schema) == "Column") { + jschema <- schema@jc + } else if (is.character(schema)) { + jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema) + } else { + stop("schema argument should be a column or character") + } + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "from_csv", + x@jc, jschema, options) + column(jc) + }) + #' @details #' \code{from_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT #' TIMEZONE. 
This function takes a timestamp which is timezone-agnostic, and interprets it as a diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 697d124095a75..d501f73b0b7b9 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -984,6 +984,10 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("from_csv", function(x, schema, ...) { standardGeneric("from_csv") }) + #' @rdname column_datetime_functions #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5cc75aa3f3673..5ad5d78d3ed17 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,6 +1647,13 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + # Test from_csv() + df <- as.DataFrame(list(list("col" = "1"))) + c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + # Test to_json(), from_json() df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") j <- collect(select(df, alias(to_json(df$people), "json"))) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5425d311f8c7f..32d7f02f61883 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,9 +25,12 @@ if sys.version < "3": from itertools import imap as map +if sys.version >= '3': + basestring = str + from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 @@ -2678,6 +2681,38 @@ def sequence(start, stop, step=None): _to_java_column(start), _to_java_column(stop), _to_java_column(step))) +@ignore_unicode_prefix +@since(3.0) +def from_csv(col, schema, options={}): + """ + Parses a column containing a CSV string to a row with the specified schema. + Returns `null`, in the case of an unparseable string. + + :param col: string column in CSV format + :param schema: a string with schema in DDL format to use when parsing the CSV column. + :param options: options to control parsing. 
accepts the same options as the CSV datasource + + >>> data = [(1, '1')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() + [Row(csv=Row(a=1))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() + [Row(csv=Row(a=1))] + """ + + sc = SparkContext._active_spark_context + if isinstance(schema, basestring): + schema = _create_column_from_literal(schema) + elif isinstance(schema, Column): + schema = _to_java_column(schema) + else: + raise TypeError("schema argument should be a column or string") + + jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 2e7df4fd14042..16ecebf159c1f 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -103,6 +103,12 @@ commons-codec commons-codec + + com.univocity + univocity-parsers + 2.7.3 + jar + target/scala-${scala.binary.version}/classes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebff79874..38f5c02910f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -520,7 +520,10 @@ object FunctionRegistry { castAlias("date", DateType), castAlias("timestamp", TimestampType), castAlias("binary", BinaryType), - castAlias("string", StringType) + castAlias("string", StringType), + + // csv + expression[CsvToStructs]("from_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala new file mode 100644 index 0000000000000..bbe27831f01df --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +object CSVExprUtils { + /** + * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). + * This is currently being used in CSV reading path and CSV schema inference. 
+ */ + def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + iter.filter { line => + line.trim.nonEmpty && !line.startsWith(options.comment.toString) + } + } + + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { + val commentPrefix = options.comment.toString + iter.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + iter.dropWhile(_.trim.isEmpty) + } + } + + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + (str: Seq[Char]) match { + case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string") + case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." + + " It has special meaning as beginning of an escape sequence." + + " To get the backslash character, pass a string with two backslashes as the delimiter.") + case Seq(c) => c + case Seq('\\', 't') => '\t' + case Seq('\\', 'r') => '\r' + case Seq('\\', 'b') => '\b' + case Seq('\\', 'f') => '\f' + // In case user changes quote char and uses \" as delimiter in options + case Seq('\\', '\"') => '\"' + case Seq('\\', '\'') => '\'' + case Seq('\\', '\\') => '\\' + case _ if str == """\u0000""" => '\u0000' + case Seq('\\', _) => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + case _ => + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala index 558ee91c419b9..c39f77e891ae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import com.univocity.parsers.csv.CsvParser @@ -123,7 +123,7 @@ class CSVHeaderChecker( // Note: if there are only comments in the first block, the header would probably // be not extracted. 
if (options.headerFlag && isStartOfFile) { - CSVUtils.extractHeader(lines, options).foreach { header => + CSVExprUtils.extractHeader(lines, options).foreach { header => checkHeaderColumnNames(tokenizer.parseLine(header)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 492a21be6df3b..3e25d820e9941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} @@ -83,7 +83,7 @@ class CSVOptions( } } - val delimiter = CSVUtils.toChar( + val delimiter = CSVExprUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index fbd19c6e677e5..46ed58ed92830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.io.InputStream import java.math.BigDecimal @@ -28,8 +28,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} -import org.apache.spark.sql.execution.datasources.FailureSafeParser +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -264,7 +263,7 @@ class UnivocityParser( } } -private[csv] object UnivocityParser { +private[sql] object UnivocityParser { /** * Parses a stream that contains CSV strings and turns it into an iterator of tokens. 
@@ -339,7 +338,7 @@ private[csv] object UnivocityParser { val options = parser.options - val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) + val filteredLines: Iterator[String] = CSVExprUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala new file mode 100644 index 0000000000000..e5708894f22b4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.{MapType, StringType, StructType} + +object ExprUtils { + + def evalSchemaExpr(exp: Expression): StructType = exp match { + case Literal(s, StringType) => StructType.fromDDL(s.toString) + case e => throw new AnalysisException( + s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + } + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala new file mode 100644 index 0000000000000..a63b6245c499e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a CSV input string to a [[StructType]] with the specified schema. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`.", + examples = """ + Examples: + > SELECT _FUNC_('1, 0.8', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT _FUNC_('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) + {"time":2015-08-26 00:00:00.0} + """, + since = "3.0.0") +// scalastyle:on line.size.limit +case class CsvToStructs( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression + with TimeZoneAwareExpression + with CodegenFallback + with ExpectsInputTypes + with NullIntolerant { + + override def nullable: Boolean = child.nullable + + // The CSV input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. + val nullableSchema: StructType = schema.asNullable + + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression, options: Map[String, String]) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = options, + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = ExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + + // This converts parsed rows to the desired output by the given schema. + @transient + lazy val converter = (rows: Iterator[InternalRow]) => { + if (rows.hasNext) { + val result = rows.next() + // CSV's parser produces one record only. + assert(!rows.hasNext) + result + } else { + throw new IllegalArgumentException("Expected one row from CSV parser.") + } + } + + @transient lazy val parser = { + val parsedOptions = new CSVOptions(options, columnPruning = true, timeZoneId.get) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. 
" + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") + } + val actualSchema = + StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) + new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + mode, + nullableSchema, + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) + } + + override def dataType: DataType = nullableSchema + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { + copy(timeZoneId = Option(timeZoneId)) + } + + override def nullSafeEval(input: Any): Any = { + val csv = input.asInstanceOf[UTF8String].toString + converter(parser.parse(csv)) + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f5297dde10ed6..9f2848365bf40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -539,7 +539,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.evalSchemaExpr(schema), - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -650,7 +650,7 @@ case class StructsToJson( def this(child: Expression) = this(Map.empty, child, None) def this(child: Expression, options: Expression) = this( - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -754,7 +754,7 @@ case class SchemaOfJson( def this(child: Expression, options: Expression) = this( child = child, - options = JsonExprUtils.convertToMapData(options)) + options = ExprUtils.convertToMapData(options)) @transient private lazy val jsonOptions = new JSONOptions(options, "UTC") @@ -777,7 +777,6 @@ case class SchemaOfJson( } object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) case e @ SchemaOfJson(_: Literal, _) => @@ -787,18 +786,4 @@ object JsonExprUtils { "Schema should be specified in DDL format as a string literal" + s" or output of the schema_of_json function instead of ${e.sql}") } - - def convertToMapData(exp: Expression): Map[String, String] = exp match { - case m: CreateMap - if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => - val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] - ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => - key.toString -> value.toString - } - case m: CreateMap => - throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") - case _ => - throw new AnalysisException("Must use a map() function for options") - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala rename to 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 90e81661bae7a..fecfff5789a5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala similarity index 72% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala index 60fcbd2ff008c..838ac42184fa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala @@ -15,46 +15,46 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import org.apache.spark.SparkFunSuite -class CSVUtilsSuite extends SparkFunSuite { +class CSVExprUtilsSuite extends SparkFunSuite { test("Can parse escaped characters") { - assert(CSVUtils.toChar("""\t""") === '\t') - assert(CSVUtils.toChar("""\r""") === '\r') - assert(CSVUtils.toChar("""\b""") === '\b') - assert(CSVUtils.toChar("""\f""") === '\f') - assert(CSVUtils.toChar("""\"""") === '\"') - assert(CSVUtils.toChar("""\'""") === '\'') - assert(CSVUtils.toChar("""\u0000""") === '\u0000') - assert(CSVUtils.toChar("""\\""") === '\\') + assert(CSVExprUtils.toChar("""\t""") === '\t') + assert(CSVExprUtils.toChar("""\r""") === '\r') + assert(CSVExprUtils.toChar("""\b""") === '\b') + assert(CSVExprUtils.toChar("""\f""") === '\f') + assert(CSVExprUtils.toChar("""\"""") === '\"') + assert(CSVExprUtils.toChar("""\'""") === '\'') + assert(CSVExprUtils.toChar("""\u0000""") === '\u0000') + assert(CSVExprUtils.toChar("""\\""") === '\\') } test("Does not accept delimiter larger than one character") { val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("ab") + CSVExprUtils.toChar("ab") } assert(exception.getMessage.contains("cannot be more than one character")) } test("Throws exception for unsupported escaped characters") { val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("""\1""") + CSVExprUtils.toChar("""\1""") } assert(exception.getMessage.contains("Unsupported special character for delimiter")) } test("string with one backward slash is prohibited") { val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("""\""") + CSVExprUtils.toChar("""\""") } assert(exception.getMessage.contains("Single backslash is prohibited")) } test("output proper error message for empty string") { val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("") + CSVExprUtils.toChar("") } assert(exception.getMessage.contains("Delimiter cannot be empty string")) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala new file mode 100644 index 0000000000000..65987af710750 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Calendar + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { + val badCsv = "\u0000\u0000\u0000A\u0001AAA" + + val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID) + + test("from_csv") { + val csvData = "1" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), + InternalRow(1) + ) + } + + test("from_csv - invalid data") { + val csvData = "---" + val schema = StructType(StructField("a", DoubleType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(csvData), gmtId), + InternalRow(null)) + + // Default mode is Permissive + checkEvaluation(CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), InternalRow(null)) + } + + test("from_csv null input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + null + ) + } + + test("from_csv bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(badCsv), gmtId), + InternalRow(null)) + } + + test("from_csv with timestamp") { + val schema = StructType(StructField("t", TimestampType) :: Nil) + + val csvData1 = "2016-01-01T00:00:00.123Z" + var c = Calendar.getInstance(DateTimeUtils.TimeZoneGMT) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + // The result doesn't change because the CSV string includes timezone string ("Z" here), + // which means the string represents the timestamp string in the timezone regardless of + // the timeZoneId parameter. 
+ checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), Option("PST")), + InternalRow(c.getTimeInMillis * 1000L) + ) + + val csvData2 = "2016-01-01T00:00:00" + for (tz <- DateTimeTestUtils.outstandingTimezones) { + c = Calendar.getInstance(tz) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + Literal(csvData2), + Option(tz.getID)), + InternalRow(c.getTimeInMillis * 1000L) + ) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), + Literal(csvData2), + gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + } + } + + test("from_csv empty input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + InternalRow(null) + ) + } + + test("forcing schema nullability") { + val input = """1,,"foo"""" + val csvSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = CsvToStructs(csvSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = csvSchema.asNullable + assert(schemaToCompare == schema) + } + + + test("from_csv missing columns") { + val schema = new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create("1"), gmtId), + InternalRow(1, null) + ) + } + + test("unsupported mode") { + val csvData = "---" + val schema = StructType(StructField("a", DoubleType) :: Nil) + val exception = intercept[TestFailedException] { + checkEvaluation( + CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), gmtId), + InternalRow(null)) + }.getCause + assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3af70b5153c83..4f6d8b8a0c34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,16 +22,17 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper -import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 0b5a719d427c9..9e7b45db9f280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 3de1c2d955d20..954a5a9cdecbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3596ff105fd7f..4326a186d6d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -23,6 +23,7 @@ import scala.util.control.Exception._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 0a7473c491b12..21fabac472f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.csv.CSVExprUtils +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.functions._ object CSVUtils { @@ -39,16 +41,6 @@ object CSVUtils { } } - /** - * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). - * This is currently being used in CSV reading path and CSV schema inference. 
- */ - def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - iter.filter { line => - line.trim.nonEmpty && !line.startsWith(options.comment.toString) - } - } - /** * Skip the given first line so that only data can remain in a dataset. * This is similar with `dropHeaderLine` below and currently being used in CSV schema inference. @@ -67,29 +59,6 @@ object CSVUtils { } } - def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - if (options.isCommentSet) { - val commentPrefix = options.comment.toString - iter.dropWhile { line => - line.trim.isEmpty || line.trim.startsWith(commentPrefix) - } - } else { - iter.dropWhile(_.trim.isEmpty) - } - } - - /** - * Extracts header and moves iterator forward so that only data remains in it - */ - def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { - val nonEmptyLines = skipComments(iter, options) - if (nonEmptyLines.hasNext) { - Some(nonEmptyLines.next()) - } else { - None - } - } - /** * Generates a header from the given row which is null-safe and duplicate-safe. */ @@ -132,35 +101,6 @@ object CSVUtils { } } - /** - * Helper method that converts string representation of a character to actual character. - * It handles some Java escaped strings and throws exception if given string is longer than one - * character. - */ - @throws[IllegalArgumentException] - def toChar(str: String): Char = { - (str: Seq[Char]) match { - case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string") - case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." + - " It has special meaning as beginning of an escape sequence." + - " To get the backslash character, pass a string with two backslashes as the delimiter.") - case Seq(c) => c - case Seq('\\', 't') => '\t' - case Seq('\\', 'r') => '\r' - case Seq('\\', 'b') => '\b' - case Seq('\\', 'f') => '\f' - // In case user changes quote char and uses \" as delimiter in options - case Seq('\\', '\"') => '\"' - case Seq('\\', '\'') => '\'' - case Seq('\\', '\\') => '\\' - case _ if str == """\u0000""" => '\u0000' - case Seq('\\', _) => - throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") - case _ => - throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") - } - } - /** * Sample CSV dataset as configured by `samplingRatio`. 
*/ @@ -186,4 +126,7 @@ + csv.sample(withReplacement = false, options.samplingRatio, 1) } } + + def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = + CSVExprUtils.filterCommentAndEmpty(iter, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala index 4082a0df8ba75..37d9d9abc8680 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala @@ -22,6 +22,7 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 76f58371ae264..c7608e2e881ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4247d3110f1e1..8def9967cffb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3854,6 +3854,38 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + /** + * Parses a column containing a CSV string into a `StructType` with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the CSV string + * @param options options to control how the CSV is parsed. Accepts the same options as the + * CSV data source. + * + * @group collection_funcs + * @since 3.0.0 + */ + def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + CsvToStructs(schema, options, e.expr) + } + + /** + * (Java-specific) Parses a column containing a CSV string into a `StructType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the CSV string + * @param options options to control how the CSV is parsed. Accepts the same options as the + * CSV data source.
+ * + * @group collection_funcs + * @since 3.0.0 + */ + def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql new file mode 100644 index 0000000000000..d2214fd016028 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -0,0 +1,9 @@ +-- from_csv +select from_csv('1, 3.14', 'a INT, f FLOAT'); +select from_csv('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select from_csv('1', 1); +select from_csv('1', 'a InvalidType'); +select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); +select from_csv('1', 'a INT', map('mode', 1)); +select from_csv(); diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out new file mode 100644 index 0000000000000..15dbe36bc0f6a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +select from_csv('1, 3.14', 'a INT, f FLOAT') +-- !query 0 schema +struct> +-- !query 0 output +{"a":1,"f":3.14} + + +-- !query 1 +select from_csv('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) +-- !query 1 schema +struct> +-- !query 1 output +{"time":2015-08-26 00:00:00.0} + + +-- !query 2 +select from_csv('1', 1) +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7 + + +-- !query 3 +select from_csv('1', 'a InvalidType') +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException + +DataType invalidtype is not supported.(line 1, pos 2) + +== SQL == +a InvalidType +--^^^ +; line 1 pos 7 + + +-- !query 4 +select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 5 +select from_csv('1', 'a INT', map('mode', 1)) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got map;; line 1 pos 7 + + +-- !query 6 +select from_csv() +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala new file mode 100644 index 0000000000000..38a2143d6d0f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +class CsvFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("from_csv with empty options") { + val df = Seq("1").toDS() + val schema = "a int" + + checkAnswer( + df.select(from_csv($"value", lit(schema), Map[String, String]().asJava)), + Row(Row(1)) :: Nil) + } + + test("from_csv with option") { + val df = Seq("26/08/2015 18:00").toDS() + val schema = new StructType().add("time", TimestampType) + val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm") + + checkAnswer( + df.select(from_csv($"value", schema, options)), + Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) + } + + + test("checking the columnNameOfCorruptRecord option") { + val columnNameOfCorruptRecord = "_unparsed" + val df = Seq("0,2013-111-11 12:13:14", "1,1983-08-04").toDS() + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = df + .select(from_csv($"value", schemaWithCorrField1, Map( + "mode" -> "Permissive", "columnNameOfCorruptRecord" -> columnNameOfCorruptRecord))) + + checkAnswer(df2, Seq( + Row(Row(null, null, "0,2013-111-11 12:13:14")), + Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 57e36e082653c..6b64f2ffa98dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index 458edb253fb33..6f231142949d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String
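For reviewers who want to exercise the new API end to end, a minimal Scala usage sketch follows. It is not part of the patch: the object name FromCsvExample and the local-mode SparkSession are illustrative assumptions; everything else relies only on what this diff introduces (the two from_csv overloads added to org.apache.spark.sql.functions and the from_csv function registered in FunctionRegistry).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_csv, lit}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object FromCsvExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session, only for trying the API out.
    val spark = SparkSession.builder().master("local[*]").appName("from_csv example").getOrCreate()
    import spark.implicits._

    val df = Seq("Amsterdam,2018").toDF("value")

    // Scala-specific overload: a StructType schema plus an options map.
    val schema = new StructType().add("city", StringType).add("year", IntegerType)
    df.select(from_csv($"value", schema, Map.empty[String, String])).show(false)

    // Java-specific overload: the schema is a DDL string wrapped in lit(),
    // which is also how the R and Python bindings in this patch pass string schemas.
    df.select(from_csv($"value", lit("city STRING, year INT"),
      new java.util.HashMap[String, String]())).show(false)

    // The same expression is registered in FunctionRegistry, so it is usable from SQL.
    spark.sql("SELECT from_csv('Amsterdam,2018', 'city STRING, year INT')").show(false)

    spark.stop()
  }
}

Note that when the schema is given as a column, ExprUtils.evalSchemaExpr requires a string literal in DDL format, so only foldable schemas such as lit("...") work; non-literal schema expressions raise the AnalysisException shown in the csv-functions.sql test output above.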