From 57aacdd2492202f876d8dfd80daf33949a5a859e Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sun, 7 Oct 2018 16:37:50 +0200
Subject: [PATCH 01/26] Moving CSVUtils to sql/catalyst

---
 .../spark/sql/catalyst/csv/CSVUtils.scala     | 47 +++++++++++++++++++
 .../sql/catalyst/csv/UnivocityParser.scala    |  1 +
 .../sql/catalyst/csv/CSVUtilsSuite.scala      | 47 +++++++++++++++++++
 .../execution/datasources/csv/CSVUtils.scala  | 42 -----------------
 4 files changed, 95 insertions(+), 42 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala
new file mode 100644
index 0000000000000..9e3fc701b4480
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.csv
+
+object CSVUtils {
+  /**
+   * Helper method that converts the string representation of a character to the actual
+   * character. It handles some Java escape sequences and throws an exception if the given
+   * string is longer than one character.
+ */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 46ed58ed92830..8669610958e68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,6 +27,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala new file mode 100644 index 0000000000000..3217df9aed335 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.csv + +import org.apache.spark.SparkFunSuite + +class CSVUtilsSuite extends SparkFunSuite { + test("Can parse escaped characters") { + assert(CSVUtils.toChar("""\t""") === '\t') + assert(CSVUtils.toChar("""\r""") === '\r') + assert(CSVUtils.toChar("""\b""") === '\b') + assert(CSVUtils.toChar("""\f""") === '\f') + assert(CSVUtils.toChar("""\"""") === '\"') + assert(CSVUtils.toChar("""\'""") === '\'') + assert(CSVUtils.toChar("""\u0000""") === '\u0000') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVUtils.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVUtils.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 21fabac472f4b..5bd19ce7f77ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -59,48 +59,6 @@ object CSVUtils { } } - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - // scalastyle:off caselocale - .map(name => if (caseSensitive) name else name.toLowerCase) - // scalastyle:on caselocale - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - // scalastyle:off caselocale - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // scalastyle:on caselocale - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. - s"_c$index" - } - } - } - /** * Sample CSV dataset as configured by `samplingRatio`. 
*/ From 4f1f25aefff531a9bb71ac3e3818484f5da7ccb2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 7 Oct 2018 19:42:02 +0200 Subject: [PATCH 02/26] Moving CSVInferSchema to sql/catalyst --- .../sql/catalyst}/csv/CSVInferSchema.scala | 5 +- .../spark/sql/catalyst/csv/CSVUtils.scala | 2 +- .../catalyst}/csv/CSVInferSchemaSuite.scala | 3 +- .../csv/UnivocityParserSuite.scala | 200 ------------------ 4 files changed, 4 insertions(+), 206 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/csv/CSVInferSchema.scala (98%) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/datasources => catalyst/src/test/scala/org/apache/spark/sql/catalyst}/csv/CSVInferSchemaSuite.scala (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 4326a186d6d5f..20883f7a95160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal @@ -23,11 +23,10 @@ import scala.util.control.Exception._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[csv] object CSVInferSchema { +private[sql] object CSVInferSchema { /** * Similar to the JSON schema inference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala index 9e3fc701b4480..a0401c6f4ff03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala @@ -44,4 +44,4 @@ object CSVUtils { throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") } } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 6b64f2ffa98dd..651846d2ebcb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index 6f231142949d1..e69de29bb2d1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.math.BigDecimal - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class UnivocityParserSuite extends SparkFunSuite { - private val parser = new UnivocityParser( - StructType(Seq.empty), - new CSVOptions(Map.empty[String, String], false, "GMT")) - - private def assertNull(v: Any) = assert(v == null) - - test("Can parse decimal type values") { - val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") - val decimalValues = Seq(10.05, 1000.01, 158058049.001) - val decimalType = new DecimalType() - - stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => - val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === - Decimal(decimalValue, decimalType.precision, decimalType.scale)) - } - } - - test("Nullable types are handled") { - val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType) - - // Nullable field with nullValue option. - types.foreach { t => - // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) - assertNull(converter.apply("-")) - assertNull(converter.apply(null)) - - // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) - } - - // Not nullable field with nullValue option. 
- types.foreach { t => - // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = false, options = options) - var message = intercept[RuntimeException] { - converter.apply("-") - }.getMessage - assert(message.contains("null value found but field _1 is not nullable.")) - message = intercept[RuntimeException] { - converter.apply(null) - }.getMessage - assert(message.contains("null value found but field _1 is not nullable.")) - } - - // If nullValue is different with empty string, then, empty string should not be casted into - // null. - Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - val converter = - parser.makeConverter("_1", StringType, nullable = b, options = options) - assert(converter.apply("") == UTF8String.fromString("")) - } - } - - test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - val exception = intercept[RuntimeException]{ - parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") - } - assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) - } - - test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", LongType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", FloatType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", DoubleType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) - - val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") - val customTimestamp = "31/01/2015 00:00" - val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime - val castedTimestamp = - parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions) - .apply(customTimestamp) - assert(castedTimestamp == expectedTime * 1000L) - - val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") - val expectedDate = dateOptions.dateFormat.parse(customDate).getTime - val castedDate = - parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) - .apply(customTimestamp) - assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) - - val timestamp = "2015-01-01 00:00:00" - assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) == - DateTimeUtils.stringToTime(timestamp).getTime * 1000L) - assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") == - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) - } - - test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - val types = Seq(DoubleType, FloatType) - val input = Seq("10u000", "abc", "1 2/3") - types.foreach { dt => - input.foreach { v => - val message = intercept[NumberFormatException] { - 
parser.makeConverter("_1", dt, options = options).apply(v) - }.getMessage - assert(message.contains(v)) - } - } - } - - test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") - val floatVal: Float = parser.makeConverter( - "_1", FloatType, nullable = true, options = options - ).apply("nn").asInstanceOf[Float] - - // Java implements the IEEE-754 floating point standard which guarantees that any comparison - // against NaN will return false (except != which returns true) - assert(floatVal != floatVal) - } - - test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") - val doubleVal: Double = parser.makeConverter( - "_1", DoubleType, nullable = true, options = options - ).apply("-").asInstanceOf[Double] - - assert(doubleVal.isNaN) - } - - test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") - val floatVal1 = parser.makeConverter( - "_1", FloatType, nullable = true, options = negativeInfOptions - ).apply("max").asInstanceOf[Float] - - assert(floatVal1 == Float.NegativeInfinity) - - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") - val floatVal2 = parser.makeConverter( - "_1", FloatType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Float] - - assert(floatVal2 == Float.PositiveInfinity) - } - - test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") - val doubleVal1 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = negativeInfOptions - ).apply("max").asInstanceOf[Double] - - assert(doubleVal1 == Double.NegativeInfinity) - - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") - val doubleVal2 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Double] - - assert(doubleVal2 == Double.PositiveInfinity) - } - -} From 2e98585782db3ecbb65f53ca9696db56333af173 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 16 Sep 2018 21:12:58 +0200 Subject: [PATCH 03/26] Added an expression test --- .../sql/catalyst/csv/CSVInferSchema.scala | 25 +++++++++----- .../catalyst/expressions/csvExpressions.scala | 33 +++++++++++++++++++ .../expressions/CsvExpressionsSuite.scala | 3 ++ 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 20883f7a95160..566edb44f6123 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal -import scala.util.control.Exception._ +import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion @@ -43,13 +43,7 @@ private[sql] object CSVInferSchema { val rootTypes: Array[DataType] = tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) - header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other - } - StructField(thisHeader, dType, nullable = true) - } + 
toStructFields(rootTypes, header, options)
     } else {
       // By default fields are assumed to be StringType
       header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -58,7 +52,20 @@ private[sql] object CSVInferSchema {
     StructType(fields)
   }
 
-  private def inferRowType(options: CSVOptions)
+  def toStructFields(
+      fieldTypes: Array[DataType],
+      header: Array[String],
+      options: CSVOptions): Array[StructField] = {
+    header.zip(fieldTypes).map { case (thisHeader, rootType) =>
+      val dType = rootType match {
+        case _: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
+    }
+  }
+
+  def inferRowType(options: CSVOptions)
     (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 853b1ea6a5f1c..9d8422a084be0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv._
@@ -120,3 +122,34 @@ case class CsvToStructs(
 
   override def prettyName: String = "from_csv"
 }
+
+/**
+ * A function that infers the schema of a CSV string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(csv[, options]) - Returns the schema of a CSV string in DDL format.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('1,abc');
+       struct<_c0:int,_c1:string>
+  """,
+  since = "3.0.0")
+case class SchemaOfCsv(child: Expression)
+  extends UnaryExpression with String2StringExpression with CodegenFallback {
+
+  override def convert(v: UTF8String): UTF8String = {
+    val parsedOptions = new CSVOptions(Map.empty, true, "UTC")
+    val parser = new CsvParser(parsedOptions.asParserSettings)
+    val row = parser.parseLine(v.toString)
+
+    if (row != null) {
+      val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+      val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
+      val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+      UTF8String.fromString(st.catalogString)
+    } else {
+      null
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
index 65987af710750..d481c1bfb7af4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -154,5 +154,8 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
       InternalRow(null))
     }.getCause
     assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode"))
+
+  test("infer schema of CSV strings") {
+    checkEvaluation(SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>")
   }
 }

From 1636db54ff654e83c2f982a9b49eb720896c08a8 Mon Sep 17 00:00:00 2001
From: Maxim
Gekk Date: Fri, 21 Sep 2018 17:03:39 +0200 Subject: [PATCH 04/26] Support options --- .../sql/catalyst/expressions/csvExpressions.scala | 14 +++++++++++--- .../catalyst/expressions/CsvExpressionsSuite.scala | 9 ++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 9d8422a084be0..52a64adc74f13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -133,12 +133,20 @@ case class CsvToStructs( > SELECT _FUNC_('1,abc'); struct<_c0:int,_c1:string> """, - since = "3.0.0") -case class SchemaOfCsv(child: Expression) + since = "2.5.0") +case class SchemaOfCsv( + child: Expression, + options: Map[String, String]) extends UnaryExpression with String2StringExpression with CodegenFallback { + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = ExprUtils.convertToMapData(options)) + override def convert(v: UTF8String): UTF8String = { - val parsedOptions = new CSVOptions(Map.empty, true, "UTC") + val parsedOptions = new CSVOptions(options, true, "UTC") val parser = new CsvParser(parsedOptions.asParserSettings) val row = parser.parseLine(v.toString) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index d481c1bfb7af4..c4558244d4e75 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -156,6 +156,13 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) test("infer schema of CSV strings") { - checkEvaluation(SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + } + + test("infer schema of CSV strings by using options") { + checkEvaluation( + new SchemaOfCsv(Literal.create("1|abc"), + CreateMap(Seq(Literal.create("delimiter"), Literal.create("|")))), + "struct<_c0:int,_c1:string>") } } From 1ab4e8bab2aa3d40e7cc16e65715d2b0f5a85283 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 21 Sep 2018 18:05:55 +0200 Subject: [PATCH 05/26] Register schema_of_csv and adding SQL tests --- .../catalyst/analysis/FunctionRegistry.scala | 3 ++- .../sql-tests/inputs/csv-functions.sql | 3 +++ .../sql-tests/results/csv-functions.sql.out | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..9ebb44d14d678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -523,7 +523,8 @@ object FunctionRegistry { castAlias("string", StringType), // csv - 
expression[CsvToStructs]("from_csv") + expression[CsvToStructs]("from_csv"), + expression[SchemaOfCsv]("schema_of_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index d2214fd016028..146b8390b3bd5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -7,3 +7,6 @@ select from_csv('1', 'a InvalidType'); select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_csv('1', 'a INT', map('mode', 1)); select from_csv(); +-- infer schema of json literal +select schema_of_csv('1,abc'); +select schema_of_csv('1|abc', map('delimiter', '|')); diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index f19f34a773c16..4786b7287bf47 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -67,3 +67,20 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 +-- Number of queries: 2 + + +-- !query 0 +select schema_of_csv('1,abc') +-- !query 0 schema +struct +-- !query 0 output +struct<_c0:int,_c1:string> + + +-- !query 1 +select schema_of_csv('1|abc', map('delimiter', '|')) +-- !query 1 schema +struct +-- !query 1 output +struct<_c0:int,_c1:string> From b843fec5a3b1989d048a4064e194119e8537e95f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 21 Sep 2018 19:27:33 +0200 Subject: [PATCH 06/26] Adding schema_of_csv and tests --- .../org/apache/spark/sql/functions.scala | 25 +++++++++++++++++++ .../apache/spark/sql/CsvFunctionsSuite.scala | 5 ++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 757a3226855c5..5dae78d96e375 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3864,6 +3864,7 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } +<<<<<<< HEAD /** * Parses a column containing a CSV string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. @@ -3894,6 +3895,30 @@ object functions { */ def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) + + /** + * Parses a column containing a CSV string and infers its schema. + * + * @param e a string column containing CSV data. + * + * @group collection_funcs + * @since 2.5.0 + */ + def schema_of_csv(e: Column): Column = withExpr(new SchemaOfCsv(e.expr)) + + /** + * Parses a column containing a CSV string and infers its schema using options. + * + * @param e a string column containing CSV data. + * @param options options to control how the CSV is parsed. accepts the same options and the + * json data source. See [[DataFrameReader#csv]]. + * @return a column with string literal containing schema in DDL format. 
+ * + * @group collection_funcs + * @since 2.5.0 + */ + def schema_of_csv(e: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(e.expr, options.asScala.toMap)) } // scalastyle:off line.size.limit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 38a2143d6d0f0..f81d895f37088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -58,5 +58,10 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df2, Seq( Row(Row(null, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) + + test("infers schemas using options") { + val df = spark.range(1) + .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) + checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) } } From 101739fb9afbe58bc1ebc99aec121f1756496a79 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 21 Sep 2018 19:54:43 +0200 Subject: [PATCH 07/26] Support schema_of_csv in PySpark --- python/pyspark/sql/functions.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ca2a256983d67..4a03e679fdd30 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2364,6 +2364,27 @@ def schema_of_json(json, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.5) +def schema_of_csv(col, options={}): + """ + Parses a column containing a CSV string and infers its schema in DDL format. + + :param col: string column in CSV format + :param options: options to control parsing. accepts the same options as the CSV datasource + + >>> from pyspark.sql.types import * + >>> data = [(1, '1|a')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_csv(df.value, {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_csv(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ From 8b9a1a4c91c95e0e406fa9e40fdda12e6e9d4161 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 7 Oct 2018 20:49:00 +0200 Subject: [PATCH 08/26] 2.5 -> 3.0 --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/csvExpressions.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4a03e679fdd30..ec6dc9a0e7891 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2365,7 +2365,7 @@ def schema_of_json(json, options={}): @ignore_unicode_prefix -@since(2.5) +@since(3.0) def schema_of_csv(col, options={}): """ Parses a column containing a CSV string and infers its schema in DDL format. 
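
A quick usage sketch, not itself part of any patch above, mirroring the Python
doctest shown earlier with the Scala variants added in PATCH 06. It assumes a
running SparkSession named `spark`; the inputs and expected DDL strings are the
ones the test suites in this series assert on.

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.functions.{lit, schema_of_csv}

    // One-argument form with default options; expected: struct<_c0:int,_c1:string>
    val ddl1 = spark.range(1).select(schema_of_csv(lit("1,abc"))).first.getString(0)

    // Overload taking options as a java.util.Map, per the signature added above;
    // CsvFunctionsSuite expects struct<_c0:double,_c1:int> for this input.
    val ddl2 = spark.range(1)
      .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava))
      .first.getString(0)
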
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 52a64adc74f13..ecf49a118ace6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -133,7 +133,7 @@ case class CsvToStructs( > SELECT _FUNC_('1,abc'); struct<_c0:int,_c1:string> """, - since = "2.5.0") + since = "3.0.0") case class SchemaOfCsv( child: Expression, options: Map[String, String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5dae78d96e375..cb0e713ce3bdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3902,7 +3902,7 @@ object functions { * @param e a string column containing CSV data. * * @group collection_funcs - * @since 2.5.0 + * @since 3.0.0 */ def schema_of_csv(e: Column): Column = withExpr(new SchemaOfCsv(e.expr)) @@ -3915,7 +3915,7 @@ object functions { * @return a column with string literal containing schema in DDL format. * * @group collection_funcs - * @since 2.5.0 + * @since 3.0.0 */ def schema_of_csv(e: Column, options: java.util.Map[String, String]): Column = { withExpr(SchemaOfCsv(e.expr, options.asScala.toMap)) From 9a1bb079c35122ec7c83c7cfb11688d0fe0a93f8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 12 Oct 2018 19:08:58 +0200 Subject: [PATCH 09/26] Merging the master branch --- .../sql/catalyst/csv/CSVHeaderChecker.scala | 4 ++-- .../spark/sql/catalyst/csv/CSVUtils.scala | 24 +++++++++++++++++++ .../sql/catalyst/csv/UnivocityParser.scala | 2 +- .../datasources/csv/CSVDataSource.scala | 2 +- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala index c39f77e891ae1..83c517ce3deef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala @@ -107,7 +107,7 @@ class CSVHeaderChecker( } // This is currently only used to parse CSV with multiLine mode. - private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { + private[sql] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { assert(options.multiLine, "This method should be executed with multiLine.") if (options.headerFlag) { val firstRecord = tokenizer.parseNext() @@ -116,7 +116,7 @@ class CSVHeaderChecker( } // This is currently only used to parse CSV with non-multiLine mode. - private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { + private[sql] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { assert(!options.multiLine, "This method should not be executed with multiline.") // Checking that column names in the header are matched to field names of the schema. // The header will be removed from lines. 
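
A behavioural sketch, not part of the patch, of the two iterator helpers the next
hunk adds to the catalyst CSVUtils; the three-argument CSVOptions constructor is
the same one the test suites in this series already use.

    import org.apache.spark.sql.catalyst.csv.{CSVOptions, CSVUtils}

    val options = new CSVOptions(Map("comment" -> "#"), false, "GMT")
    // skipComments drops leading blank lines and lines starting with the comment
    // character; extractHeader then consumes the first remaining line as the header.
    val header = CSVUtils.extractHeader(
      Iterator("# a comment", "", "a,b,c", "1,2,3"), options)
    assert(header == Some("a,b,c"))
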
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala index a0401c6f4ff03..769740d60a3cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala @@ -18,6 +18,30 @@ package org.apache.spark.sql.catalyst.csv object CSVUtils { + + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { + val commentPrefix = options.comment.toString + iter.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + iter.dropWhile(_.trim.isEmpty) + } + } + + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } + /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 8669610958e68..adc12a2d4334f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,7 +27,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions} import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 9e7b45db9f280..26bebc6ad869a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, CSVInferSchema, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType From 9efa823ee86e0c8df5cec7f7b11a12ee6e1da28b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 13 Oct 2018 09:11:27 +0200 Subject: [PATCH 10/26] Moving toChar to sql/catalyst --- .../spark/sql/catalyst/csv/CSVUtils.scala | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala index 769740d60a3cf..fc1d454c984f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala @@ -49,23 +49,25 @@ object CSVUtils { */ @throws[IllegalArgumentException] def toChar(str: String): Char = { - if (str.charAt(0) == '\\') { - str.charAt(1) - match { - case 't' => '\t' - case 'r' => '\r' - case 'b' => '\b' - case 'f' => '\f' - case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options - case '\'' => '\'' - case 'u' if str == """\u0000""" => '\u0000' - case _ => - throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") - } - } else if (str.length == 1) { - str.charAt(0) - } else { - throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + (str: Seq[Char]) match { + case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string") + case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." + + " It has special meaning as beginning of an escape sequence." + + " To get the backslash character, pass a string with two backslashes as the delimiter.") + case Seq(c) => c + case Seq('\\', 't') => '\t' + case Seq('\\', 'r') => '\r' + case Seq('\\', 'b') => '\b' + case Seq('\\', 'f') => '\f' + // In case user changes quote char and uses \" as delimiter in options + case Seq('\\', '\"') => '\"' + case Seq('\\', '\'') => '\'' + case Seq('\\', '\\') => '\\' + case _ if str == """\u0000""" => '\u0000' + case Seq('\\', _) => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + case _ => + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") } } } From e3d39c3d385e9c4bfb99c895f6708fe3251d5d0b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 11:47:25 +0800 Subject: [PATCH 11/26] build fix --- .../execution/datasources/csv/CSVUtils.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 1 - .../csv/UnivocityParserSuite.scala | 200 ++++++++++++++++++ 3 files changed, 241 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 5bd19ce7f77ac..ef3d21dff8144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -59,6 +59,47 @@ object CSVUtils { } } + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + // scalastyle:off caselocale + .map(name => if (caseSensitive) name else name.toLowerCase) + // scalastyle:on caselocale + headerNames.diff(headerNames.distinct).distinct + } + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. 
+ s"_c$index" + // scalastyle:off caselocale + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // scalastyle:on caselocale + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } + } + } + /** * Sample CSV dataset as configured by `samplingRatio`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb0e713ce3bdd..199ab39487a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3864,7 +3864,6 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } -<<<<<<< HEAD /** * Parses a column containing a CSV string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index e69de29bb2d1d..6f231142949d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class UnivocityParserSuite extends SparkFunSuite { + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) + + private def assertNull(v: Any) = assert(v == null) + + test("Can parse decimal type values") { + val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") + val decimalValues = Seq(10.05, 1000.01, 158058049.001) + val decimalType = new DecimalType() + + stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => + val decimalValue = new BigDecimal(decimalVal.toString) + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === + Decimal(decimalValue, decimalType.precision, decimalType.scale)) + } + } + + test("Nullable types are handled") { + val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType) + + // Nullable field with nullValue option. + types.foreach { t => + // Tests that a custom nullValue. + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") + val converter = + parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) + assertNull(converter.apply("-")) + assertNull(converter.apply(null)) + + // Tests that the default nullValue is empty string. + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) + } + + // Not nullable field with nullValue option. + types.foreach { t => + // Casts a null to not nullable field should throw an exception. + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") + val converter = + parser.makeConverter("_1", t, nullable = false, options = options) + var message = intercept[RuntimeException] { + converter.apply("-") + }.getMessage + assert(message.contains("null value found but field _1 is not nullable.")) + message = intercept[RuntimeException] { + converter.apply(null) + }.getMessage + assert(message.contains("null value found but field _1 is not nullable.")) + } + + // If nullValue is different with empty string, then, empty string should not be casted into + // null. 
+ Seq(true, false).foreach { b => + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") + val converter = + parser.makeConverter("_1", StringType, nullable = b, options = options) + assert(converter.apply("") == UTF8String.fromString("")) + } + } + + test("Throws exception for empty string with non null type") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val exception = intercept[RuntimeException]{ + parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") + } + assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) + } + + test("Types are cast correctly") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", LongType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", FloatType, options = options).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", DoubleType, options = options).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) + + val timestampsOptions = + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") + val customTimestamp = "31/01/2015 00:00" + val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime + val castedTimestamp = + parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions) + .apply(customTimestamp) + assert(castedTimestamp == expectedTime * 1000L) + + val customDate = "31/01/2015" + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") + val expectedDate = dateOptions.dateFormat.parse(customDate).getTime + val castedDate = + parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) + .apply(customTimestamp) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) + + val timestamp = "2015-01-01 00:00:00" + assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) == + DateTimeUtils.stringToTime(timestamp).getTime * 1000L) + assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") == + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) + } + + test("Throws exception for casting an invalid string to Float and Double Types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val types = Seq(DoubleType, FloatType) + val input = Seq("10u000", "abc", "1 2/3") + types.foreach { dt => + input.foreach { v => + val message = intercept[NumberFormatException] { + parser.makeConverter("_1", dt, options = options).apply(v) + }.getMessage + assert(message.contains(v)) + } + } + } + + test("Float NaN values are parsed correctly") { + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") + val floatVal: Float = parser.makeConverter( + "_1", FloatType, nullable = true, options = options + ).apply("nn").asInstanceOf[Float] + + // Java implements the IEEE-754 floating point standard which guarantees that any comparison + // against NaN will return false (except != which returns true) + assert(floatVal != floatVal) + } + + test("Double NaN values are parsed correctly") { + val options = new CSVOptions(Map("nanValue" -> "-"), false, 
"GMT") + val doubleVal: Double = parser.makeConverter( + "_1", DoubleType, nullable = true, options = options + ).apply("-").asInstanceOf[Double] + + assert(doubleVal.isNaN) + } + + test("Float infinite values can be parsed") { + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + val floatVal1 = parser.makeConverter( + "_1", FloatType, nullable = true, options = negativeInfOptions + ).apply("max").asInstanceOf[Float] + + assert(floatVal1 == Float.NegativeInfinity) + + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + val floatVal2 = parser.makeConverter( + "_1", FloatType, nullable = true, options = positiveInfOptions + ).apply("max").asInstanceOf[Float] + + assert(floatVal2 == Float.PositiveInfinity) + } + + test("Double infinite values can be parsed") { + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + val doubleVal1 = parser.makeConverter( + "_1", DoubleType, nullable = true, options = negativeInfOptions + ).apply("max").asInstanceOf[Double] + + assert(doubleVal1 == Double.NegativeInfinity) + + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + val doubleVal2 = parser.makeConverter( + "_1", DoubleType, nullable = true, options = positiveInfOptions + ).apply("max").asInstanceOf[Double] + + assert(doubleVal2 == Double.PositiveInfinity) + } + +} From 587caaece9ff759a879a0628b153d41fcb4231d8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 11:47:50 +0800 Subject: [PATCH 12/26] build fix --- .../spark/sql/catalyst/expressions/CsvExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index c4558244d4e75..6ea0c83d5e45b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -154,6 +154,7 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(null)) }.getCause assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + } test("infer schema of CSV strings") { checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") From d44a31929ce52a1d3d9763bd972da2b6899ab322 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 11:48:51 +0800 Subject: [PATCH 13/26] build fix --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 199ab39487a68..f8d733199e145 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3894,6 +3894,7 @@ object functions { */ def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) + } /** * Parses a column containing a CSV string and infers its schema. 
From 513f0587336a02e043915928c81e123eb97e3aa0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 11:54:56 +0800 Subject: [PATCH 14/26] build fix --- .../apache/spark/sql/catalyst}/csv/UnivocityParserSuite.scala | 3 +-- .../test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/datasources => catalyst/src/test/scala/org/apache/spark/sql/catalyst}/csv/UnivocityParserSuite.scala (98%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 6f231142949d1..e4e7dc2e8c0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index f81d895f37088..c237bc2436618 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -58,6 +58,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df2, Seq( Row(Row(null, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) + } test("infers schemas using options") { val df = spark.range(1) From ee26f0d18b119abda9b62ebbe875219571dbb495 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 12:04:05 +0800 Subject: [PATCH 15/26] fix style and build --- .../spark/sql/catalyst/csv/CSVUtils.scala | 73 ------------------- .../sql/catalyst/csv/CSVUtilsSuite.scala | 47 ------------ .../datasources/csv/CSVDataSource.scala | 2 +- .../execution/datasources/csv/CSVUtils.scala | 5 +- 4 files changed, 4 insertions(+), 123 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala deleted file mode 100644 index fc1d454c984f8..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.csv - -object CSVUtils { - - def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - if (options.isCommentSet) { - val commentPrefix = options.comment.toString - iter.dropWhile { line => - line.trim.isEmpty || line.trim.startsWith(commentPrefix) - } - } else { - iter.dropWhile(_.trim.isEmpty) - } - } - - /** - * Extracts header and moves iterator forward so that only data remains in it - */ - def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { - val nonEmptyLines = skipComments(iter, options) - if (nonEmptyLines.hasNext) { - Some(nonEmptyLines.next()) - } else { - None - } - } - - /** - * Helper method that converts string representation of a character to actual character. - * It handles some Java escaped strings and throws exception if given string is longer than one - * character. - */ - @throws[IllegalArgumentException] - def toChar(str: String): Char = { - (str: Seq[Char]) match { - case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string") - case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." + - " It has special meaning as beginning of an escape sequence." + - " To get the backslash character, pass a string with two backslashes as the delimiter.") - case Seq(c) => c - case Seq('\\', 't') => '\t' - case Seq('\\', 'r') => '\r' - case Seq('\\', 'b') => '\b' - case Seq('\\', 'f') => '\f' - // In case user changes quote char and uses \" as delimiter in options - case Seq('\\', '\"') => '\"' - case Seq('\\', '\'') => '\'' - case Seq('\\', '\\') => '\\' - case _ if str == """\u0000""" => '\u0000' - case Seq('\\', _) => - throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") - case _ => - throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala deleted file mode 100644 index 3217df9aed335..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.csv - -import org.apache.spark.SparkFunSuite - -class CSVUtilsSuite extends SparkFunSuite { - test("Can parse escaped characters") { - assert(CSVUtils.toChar("""\t""") === '\t') - assert(CSVUtils.toChar("""\r""") === '\r') - assert(CSVUtils.toChar("""\b""") === '\b') - assert(CSVUtils.toChar("""\f""") === '\f') - assert(CSVUtils.toChar("""\"""") === '\"') - assert(CSVUtils.toChar("""\'""") === '\'') - assert(CSVUtils.toChar("""\u0000""") === '\u0000') - } - - test("Does not accept delimiter larger than one character") { - val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("ab") - } - assert(exception.getMessage.contains("cannot be more than one character")) - } - - test("Throws exception for unsupported escaped characters") { - val exception = intercept[IllegalArgumentException]{ - CSVUtils.toChar("""\1""") - } - assert(exception.getMessage.contains("Unsupported special character for delimiter")) - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 26bebc6ad869a..4808e8ef042d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, CSVInferSchema, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index ef3d21dff8144..21fabac472f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -60,8 +60,8 @@ object CSVUtils { } /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ + * Generates a header from the given row which is null-safe and duplicate-safe. 
+ */ def makeSafeHeader( row: Array[String], caseSensitive: Boolean, @@ -74,6 +74,7 @@ object CSVUtils { // scalastyle:on caselocale headerNames.diff(headerNames.distinct).distinct } + row.zipWithIndex.map { case (value, index) => if (value == null || value.isEmpty || value == options.nullValue) { // When there are empty strings or the values set in `nullValue`, put the From 3580c2cca8c7003ac387286d612fe3f73d3dded4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 12:06:28 +0800 Subject: [PATCH 16/26] fix style, build resolve conflicts --- python/pyspark/sql/functions.py | 1 - .../org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala | 4 ++-- .../org/apache/spark/sql/catalyst/csv/UnivocityParser.scala | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ec6dc9a0e7891..f36d81f88080c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2373,7 +2373,6 @@ def schema_of_csv(col, options={}): :param col: string column in CSV format :param options: options to control parsing. accepts the same options as the CSV datasource - >>> from pyspark.sql.types import * >>> data = [(1, '1|a')] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(schema_of_csv(df.value, {'sep':'|'}).alias("csv")).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala index 83c517ce3deef..c39f77e891ae1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala @@ -107,7 +107,7 @@ class CSVHeaderChecker( } // This is currently only used to parse CSV with multiLine mode. - private[sql] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { + private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { assert(options.multiLine, "This method should be executed with multiLine.") if (options.headerFlag) { val firstRecord = tokenizer.parseNext() @@ -116,7 +116,7 @@ class CSVHeaderChecker( } // This is currently only used to parse CSV with non-multiLine mode. - private[sql] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { + private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { assert(!options.multiLine, "This method should not be executed with multiline.") // Checking that column names in the header are matched to field names of the schema. // The header will be removed from lines. 
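Aside (not part of the patch): the CSVHeaderChecker hunks above narrow both checkHeaderColumnNames overloads from private[sql] to private[csv], so only code under org.apache.spark.sql.catalyst.csv may call them. A minimal sketch of what that Scala access qualifier means, using hypothetical names rather than the real classes:

    // Hypothetical illustration of package-qualified visibility (not Spark code).
    package org.apache.spark.sql.catalyst.csv {
      class CheckerSketch {
        // Callable only from within org.apache.spark.sql.catalyst.csv.
        private[csv] def check(): Unit = ()
      }
      object SamePackageCaller {
        def ok(c: CheckerSketch): Unit = c.check() // compiles: same package
      }
    }
    package org.apache.spark.sql.execution {
      object OutsideCaller {
        // c.check() here would not compile: check is private[csv].
      }
    }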
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index adc12a2d4334f..46ed58ed92830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,7 +27,6 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions} import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ From cd7cfdf502f3dacac37f60749af445ce2f7af691 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 13:14:25 +0800 Subject: [PATCH 17/26] dedup JsonExprUtils and add prettyName --- .../sql/catalyst/expressions/ExprUtils.scala | 15 ++++++++++++++- .../catalyst/expressions/csvExpressions.scala | 6 ++++-- .../expressions/jsonExpressions.scala | 16 ++-------------- .../sql-tests/results/csv-functions.sql.out | 19 +++++++++---------- .../sql-tests/results/json-functions.sql.out | 2 +- 5 files changed, 30 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index e5708894f22b4..fcdd296f1e80a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -19,12 +19,25 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{MapType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String object ExprUtils { def evalSchemaExpr(exp: Expression): StructType = exp match { case Literal(s, StringType) => StructType.fromDDL(s.toString) + case e @ SchemaOfCsv(_: Literal, _) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + StructType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + } + + def evalTypeExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal, _) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index ecf49a118ace6..6838f914e4204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -62,7 +62,7 @@ case class CsvToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - 
schema = ExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalSchemaExpr(schema).asInstanceOf[StructType], options = options, child = child, timeZoneId = None) @@ -71,7 +71,7 @@ case class CsvToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = ExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalSchemaExpr(schema).asInstanceOf[StructType], options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -160,4 +160,6 @@ case class SchemaOfCsv( null } } + + override def prettyName: String = "schema_of_csv" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 77af5906010f3..eafcb6161036e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -529,7 +529,7 @@ case class JsonToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = options, child = child, timeZoneId = None) @@ -538,7 +538,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -784,15 +784,3 @@ case class SchemaOfJson( override def prettyName: String = "schema_of_json" } - -object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { - case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] - DataType.fromDDL(ddlSchema.toString) - case e => throw new AnalysisException( - "Schema should be specified in DDL format as a string literal" + - s" or output of the schema_of_json function instead of ${e.sql}") - } -} diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 4786b7287bf47..e741f73fcdf25 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 9 -- !query 0 @@ -67,20 +67,19 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_csv. 
Expected: one of 2 and 3; Found: 0; line 1 pos 7
--- Number of queries: 2
 
 
--- !query 0
+-- !query 7
 select schema_of_csv('1,abc')
--- !query 0 schema
-struct<schemaofcsv(1,abc):string>
--- !query 0 output
+-- !query 7 schema
+struct<schema_of_csv(1,abc):string>
+-- !query 7 output
 struct<_c0:int,_c1:string>
 
 
--- !query 1
+-- !query 8
 select schema_of_csv('1|abc', map('delimiter', '|'))
--- !query 1 schema
-struct<schemaofcsv(1|abc):string>
--- !query 1 output
+-- !query 8 schema
+struct<schema_of_csv(1|abc):string>
+-- !query 8 output
 struct<_c0:int,_c1:string>
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index ca0cd90d94fa7..fa8c4d4bc2a08 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -157,7 +157,7 @@ select from_json()
 struct<>
 -- !query 17 output
 org.apache.spark.sql.AnalysisException
-Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7
+Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7
 
 
 -- !query 18
From fef8a9ec807f58cab37ce762a3b0d114d3f6254d Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 19 Oct 2018 13:44:34 +0800
Subject: [PATCH 18/26] Address nits and error messages

---
 .../org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala   | 2 +-
 .../apache/spark/sql/catalyst/expressions/ExprUtils.scala    | 6 ++++--
 .../test/resources/sql-tests/results/csv-functions.sql.out   | 2 +-
 .../test/resources/sql-tests/results/json-functions.sql.out  | 2 +-
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index 566edb44f6123..799e9994451b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
-private[sql] object CSVInferSchema {
+object CSVInferSchema {
 
   /**
    * Similar to the JSON schema inference
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index fcdd296f1e80a..94e249b1468dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -30,7 +30,8 @@ object ExprUtils {
       val ddlSchema = e.eval().asInstanceOf[UTF8String]
       StructType.fromDDL(ddlSchema.toString)
     case e => throw new AnalysisException(
-      s"Schema should be specified in DDL format as a string literal instead of ${e.sql}")
+      "Schema should be specified in DDL format as a string literal or output of " +
+        s"the schema_of_csv function instead of ${e.sql}")
   }
 
   def evalTypeExpr(exp: Expression): DataType = exp match {
@@ -39,7 +40,8 @@ object ExprUtils {
       val ddlSchema = e.eval().asInstanceOf[UTF8String]
       DataType.fromDDL(ddlSchema.toString)
     case e => throw new AnalysisException(
-      s"Schema should be specified in DDL format as a string literal instead of ${e.sql}")
+      "Schema should be specified in DDL format as a string literal or output of " +
+        s"the schema_of_json function instead of ${e.sql}")
   }
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
index e741f73fcdf25..2bee34ca5a079 100644
--- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
@@ -24,7 +24,7 @@ select from_csv('1', 1)
 struct<>
 -- !query 2 output
 org.apache.spark.sql.AnalysisException
-Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7
 
 
 -- !query 3
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index fa8c4d4bc2a08..cf34f7c6daf0b 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -157,7 +157,7 @@ select from_json()
 struct<>
 -- !query 17 output
 org.apache.spark.sql.AnalysisException
-Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7
 
 
 -- !query 18
From 21e2dc4715a640a2a6fc1c349feafa090e9bd1fb Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 19 Oct 2018 13:52:22 +0800
Subject: [PATCH 19/26] Fix doctest examples to be more useful

---
 python/pyspark/sql/functions.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f36d81f88080c..c3113a6b004d1 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2684,13 +2684,13 @@ def from_csv(col, schema, options={}):
     :param schema: a string with schema in DDL format to use when parsing the CSV column.
     :param options: options to control parsing.
accepts the same options as the CSV datasource - >>> data = [(1, '1')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() - [Row(csv=Row(a=1))] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() - [Row(csv=Row(a=1))] + >>> data = [('1,2,3',)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() + [Row(csv=Row(a=1, b=2, c=3))] + >>> value = data[0][0] + >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() + [Row(csv=Row(_c0=1, _c1=2, _c2=3))] """ sc = SparkContext._active_spark_context From 6b1f408f4bc8f593342bfc93d7cdce898753549f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 14:36:16 +0800 Subject: [PATCH 20/26] Deduplicate and fix python examples to be more useful --- python/pyspark/sql/functions.py | 10 +++++-- .../sql/catalyst/expressions/ExprUtils.scala | 26 +++++++++++++------ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3113a6b004d1..8ac09c00212c2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2378,9 +2378,15 @@ def schema_of_csv(col, options={}): >>> df.select(schema_of_csv(df.value, {'sep':'|'}).alias("csv")).collect() [Row(csv=u'struct<_c0:int,_c1:string>')] """ + if isinstance(col, basestring): + col = _create_column_from_literal(col) + elif isinstance(col, Column): + col = _to_java_column(col) + else: + raise TypeError("schema argument should be a column or string") sc = SparkContext._active_spark_context - jc = sc._jvm.functions.schema_of_csv(_to_java_column(col), options) + jc = sc._jvm.functions.schema_of_csv(col, options) return Column(jc) @@ -2684,7 +2690,7 @@ def from_csv(col, schema, options={}): :param schema: a string with schema in DDL format to use when parsing the CSV column. :param options: options to control parsing. accepts the same options as the CSV datasource - >>> data = [('1,2,3',)] + >>> data = [("1,2,3",)] >>> df = spark.createDataFrame(data, ("value",)) >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() [Row(csv=Row(a=1, b=2, c=3))] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 94e249b1468dd..0cb46d984e253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -24,14 +24,24 @@ import org.apache.spark.unsafe.types.UTF8String object ExprUtils { - def evalSchemaExpr(exp: Expression): StructType = exp match { - case Literal(s, StringType) => StructType.fromDDL(s.toString) - case e @ SchemaOfCsv(_: Literal, _) => - val ddlSchema = e.eval().asInstanceOf[UTF8String] - StructType.fromDDL(ddlSchema.toString) - case e => throw new AnalysisException( - "Schema should be specified in DDL format as a string literal or output of " + - s"the schema_of_csv function instead of ${e.sql}") + def evalSchemaExpr(exp: Expression): StructType = { + // Use `DataType.fromDDL` since the type string can be struct<...>. 
+ val dataType = exp match { + case Literal(s, StringType) => + DataType.fromDDL(s.toString) + case e @ SchemaOfCsv(_: Literal, _) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_csv function instead of ${e.sql}") + } + + if (!dataType.isInstanceOf[StructType]) { + throw new AnalysisException( + s"Schema should be struct type but got ${dataType.sql}.") + } + dataType.asInstanceOf[StructType] } def evalTypeExpr(exp: Expression): DataType = exp match { From e343d4d7dfec12b9aed3954d71871287375dbc62 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 19 Oct 2018 17:18:14 +0800 Subject: [PATCH 21/26] literals only --- .../catalyst/expressions/csvExpressions.scala | 39 +++++++++++------ .../sql-tests/inputs/csv-functions.sql | 7 +++- .../sql-tests/results/csv-functions.sql.out | 42 +++++++++++++++++-- 3 files changed, 70 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 6838f914e4204..acf4f71d90e32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -21,6 +21,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.csv._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util._ @@ -137,7 +138,7 @@ case class CsvToStructs( case class SchemaOfCsv( child: Expression, options: Map[String, String]) - extends UnaryExpression with String2StringExpression with CodegenFallback { + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { def this(child: Expression) = this(child, Map.empty[String, String]) @@ -145,20 +146,32 @@ case class SchemaOfCsv( child = child, options = ExprUtils.convertToMapData(options)) - override def convert(v: UTF8String): UTF8String = { + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def nullable: Boolean = false + + @transient + private lazy val csv = child.eval().asInstanceOf[UTF8String] + + override def checkInputDataTypes(): TypeCheckResult = child match { + case Literal(s, StringType) if s != null => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"The input csv should be a string literal and not null; however, got ${child.sql}.") + } + + override def eval(v: InternalRow = EmptyRow): Any = { val parsedOptions = new CSVOptions(options, true, "UTC") val parser = new CsvParser(parsedOptions.asParserSettings) - val row = parser.parseLine(v.toString) - - if (row != null) { - val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) - val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) - UTF8String.fromString(st.catalogString) - } else { - null - } + val row = parser.parseLine(csv.toString) + assert(row != null, 
"Parsed CSV record should not be null.") + + val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) + val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + UTF8String.fromString(st.catalogString) } override def prettyName: String = "schema_of_csv" diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index 146b8390b3bd5..5be6f807931b8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -8,5 +8,10 @@ select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_csv('1', 'a INT', map('mode', 1)); select from_csv(); -- infer schema of json literal -select schema_of_csv('1,abc'); +select from_csv('1,abc', schema_of_csv('1,abc')); select schema_of_csv('1|abc', map('delimiter', '|')); +select schema_of_csv(null); +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a'); +SELECT schema_of_csv(csvField) FROM csvTable; +-- Clean up +DROP VIEW IF EXISTS csvTable; diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 2bee34ca5a079..8b6a0a8805a4a 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 13 -- !query 0 @@ -70,11 +70,11 @@ Invalid number of arguments for function from_csv. 
Expected: one of 2 and 3; Found: 0; line 1 pos 7


-- !query 7
select from_csv('1,abc', schema_of_csv('1,abc'))
-- !query 7 schema
struct<from_csv(1,abc):struct<_c0:int,_c1:string>>
-- !query 7 output
{"_c0":1,"_c1":"abc"}


-- !query 8
select schema_of_csv('1|abc', map('delimiter', '|'))
-- !query 8 schema
struct<schema_of_csv(1|abc):string>
-- !query 8 output
struct<_c0:int,_c1:string>


-- !query 9
select schema_of_csv(null)
-- !query 9 schema
struct<>
-- !query 9 output
org.apache.spark.sql.AnalysisException
cannot resolve 'schema_of_csv(CAST(NULL AS STRING))' due to data type mismatch: The input csv should be a string literal and not null; however, got CAST(NULL AS STRING).; line 1 pos 7


-- !query 10
CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a')
-- !query 10 schema
struct<>
-- !query 10 output



-- !query 11
SELECT schema_of_csv(csvField) FROM csvTable
-- !query 11 schema
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7


-- !query 12
DROP VIEW IF EXISTS csvTable
-- !query 12 schema
struct<>
-- !query 12 output

From 26fb354a900940bfee91c7e5c8bc83aa9ed36e39 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 20 Oct 2018 11:34:00 +0800
Subject: [PATCH 22/26] Address comments

---
 python/pyspark/sql/functions.py                               | 7 ++++---
 .../spark/sql/catalyst/expressions/csvExpressions.scala       | 4 ++--
 .../sql/catalyst/expressions/CsvExpressionsSuite.scala        | 3 +--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8ac09c00212c2..94a71805cedb6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2373,9 +2373,10 @@ def schema_of_csv(col, options={}):
 
     :param col: string column in CSV format
     :param options: options to control parsing.
accepts the same options as the CSV datasource - >>> data = [(1, '1|a')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(schema_of_csv(df.value, {'sep':'|'}).alias("csv")).collect() + >>> df = spark.range(1) + >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() [Row(csv=u'struct<_c0:int,_c1:string>')] """ if isinstance(col, basestring): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index acf4f71d90e32..7af21efd89ef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -63,7 +63,7 @@ case class CsvToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = ExprUtils.evalSchemaExpr(schema).asInstanceOf[StructType], + schema = ExprUtils.evalSchemaExpr(schema), options = options, child = child, timeZoneId = None) @@ -72,7 +72,7 @@ case class CsvToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = ExprUtils.evalSchemaExpr(schema).asInstanceOf[StructType], + schema = ExprUtils.evalSchemaExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 6ea0c83d5e45b..386e0d133dff6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -162,8 +162,7 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P test("infer schema of CSV strings by using options") { checkEvaluation( - new SchemaOfCsv(Literal.create("1|abc"), - CreateMap(Seq(Literal.create("delimiter"), Literal.create("|")))), + new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), "struct<_c0:int,_c1:string>") } } From b068d9f22157e0529fce87c10c8cca22db243925 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 20 Oct 2018 23:55:08 +0800 Subject: [PATCH 23/26] sync tests --- .../src/test/resources/sql-tests/results/csv-functions.sql.out | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 8b6a0a8805a4a..f212936cf6948 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -72,7 +72,7 @@ Invalid number of arguments for function from_csv. 
Expected: one of 2 and 3; Found: 0; line 1 pos 7
 
 
 -- !query 7
 select from_csv('1,abc', schema_of_csv('1,abc'))
 -- !query 7 schema
-struct>
+struct>
 -- !query 7 output
 {"_c0":1,"_c1":"abc"}
From 4696cdd1136ccb68b91a32d0dfd1ecae989bc756 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 26 Oct 2018 11:55:09 +0800
Subject: [PATCH 24/26] match to schema_of_json

---
 python/pyspark/sql/functions.py               | 14 +++++++-------
 .../sql/catalyst/expressions/ExprUtils.scala  |  2 +-
 .../catalyst/expressions/csvExpressions.scala |  6 ++---
 .../org/apache/spark/sql/functions.scala      | 25 +++++++++++++------
 .../apache/spark/sql/CsvFunctionsSuite.scala  | 11 +++++++-
 5 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 94a71805cedb6..beb1a065d2803 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2366,11 +2366,11 @@ def schema_of_json(json, options={}):
 
 @ignore_unicode_prefix
 @since(3.0)
-def schema_of_csv(col, options={}):
+def schema_of_csv(csv, options={}):
     """
-    Parses a column containing a CSV string and infers its schema in DDL format.
+    Parses a CSV string and infers its schema in DDL format.
 
-    :param col: string column in CSV format
+    :param csv: a CSV string or a string literal containing a CSV string.
     :param options: options to control parsing. accepts the same options as the CSV datasource
 
     >>> df = spark.range(1)
@@ -2379,10 +2379,10 @@ def schema_of_csv(col, options={}):
     >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
     [Row(csv=u'struct<_c0:int,_c1:string>')]
     """
-    if isinstance(col, basestring):
-        col = _create_column_from_literal(col)
-    elif isinstance(col, Column):
-        col = _to_java_column(col)
+    if isinstance(csv, basestring):
+        col = _create_column_from_literal(csv)
+    elif isinstance(csv, Column):
+        col = _to_java_column(csv)
     else:
         raise TypeError("schema argument should be a column or string")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 0cb46d984e253..0397992ef51b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -30,7 +30,7 @@ object ExprUtils {
     case Literal(s, StringType) =>
       DataType.fromDDL(s.toString)
     case e @ SchemaOfCsv(_: Literal, _) =>
-      val ddlSchema = e.eval().asInstanceOf[UTF8String]
+      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
       DataType.fromDDL(ddlSchema.toString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index acf4f71d90e32..e70296fe31292 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -138,7 +138,7 @@ case class CsvToStructs(
 case class SchemaOfCsv(
     child: Expression,
     options: Map[String, String])
-  extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
+  extends UnaryExpression with CodegenFallback {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
 
@@ -148,8 +148,6 @@ case class SchemaOfCsv(
 
   override def dataType: DataType = StringType
 
-  override def inputTypes: Seq[DataType] = Seq(StringType)
-
   override def nullable: Boolean = false
 
   @transient
@@ -161,7 +159,7 @@ case class SchemaOfCsv(
       s"The input csv should be a string literal and not null; however, got ${child.sql}.")
   }
 
-  override def eval(v: InternalRow = EmptyRow): Any = {
+  override def eval(v: InternalRow): Any = {
     val parsedOptions = new CSVOptions(options, true, "UTC")
     val parser = new CsvParser(parsedOptions.asParserSettings)
     val row = parser.parseLine(csv.toString)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f8d733199e145..0fbc59bd95ee7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3897,19 +3897,30 @@ object functions {
   }
 
   /**
-   * Parses a column containing a CSV string and infers its schema.
+   * Parses a CSV string and infers its schema in DDL format.
    *
-   * @param e a string column containing CSV data.
+   * @param csv a CSV string.
    *
    * @group collection_funcs
    * @since 3.0.0
    */
-  def schema_of_csv(e: Column): Column = withExpr(new SchemaOfCsv(e.expr))
+  def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv))
 
   /**
-   * Parses a column containing a CSV string and infers its schema using options.
+   * Parses a CSV string and infers its schema in DDL format.
    *
-   * @param e a string column containing CSV data.
+   * @param csv a string literal containing a CSV string.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr))
+
+  /**
+   * Parses a CSV string and infers its schema in DDL format using options.
+   *
+   * @param csv a string literal containing a CSV string.
    * @param options options to control how the CSV is parsed. accepts the same options and the
    *                json data source. See [[DataFrameReader#csv]].
    * @return a column with string literal containing schema in DDL format.
@@ -3917,8 +3928,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def schema_of_csv(e: Column, options: java.util.Map[String, String]): Column = { - withExpr(SchemaOfCsv(e.expr, options.asScala.toMap)) + def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) } // scalastyle:off line.size.limit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index c237bc2436618..9395f050b41ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -60,7 +60,16 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) } - test("infers schemas using options") { + test("schema_of_csv - infers schemas") { + checkAnswer( + spark.range(1).select(schema_of_csv(lit("0.1,1"))), + Seq(Row("struct<_c0:double,_c1:int>"))) + checkAnswer( + spark.range(1).select(schema_of_csv("0.1,1")), + Seq(Row("struct<_c0:double,_c1:int>"))) + } + + test("schema_of_csv - infers schemas using options") { val df = spark.range(1) .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) From b8c6c948bc4f8a08d01df6c7def86f9843ba71db Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 26 Oct 2018 15:48:25 +0800 Subject: [PATCH 25/26] updates tests --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 1 - .../src/test/resources/sql-tests/results/csv-functions.sql.out | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0fbc59bd95ee7..5c2abadca5424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3906,7 +3906,6 @@ object functions { */ def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv)) - /** * Parses a CSV string and infers its schema in DDL format. 
* diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index f212936cf6948..677bbd97c549d 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -91,7 +91,7 @@ select schema_of_csv(null) struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve 'schema_of_csv(CAST(NULL AS STRING))' due to data type mismatch: The input csv should be a string literal and not null; however, got CAST(NULL AS STRING).; line 1 pos 7 +cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7 -- !query 10 From 3aa79d4e438a84ea7566f38afd3f2a18fd7cfbed Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 27 Oct 2018 09:52:51 +0800 Subject: [PATCH 26/26] Resolve conflicts --- .../org/apache/spark/sql/catalyst/expressions/ExprUtils.scala | 2 +- .../src/test/resources/sql-tests/results/json-functions.sql.out | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 0397992ef51b5..040b56cc1caea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -47,7 +47,7 @@ object ExprUtils { def evalTypeExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval().asInstanceOf[UTF8String] + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( "Schema should be specified in DDL format as a string literal or output of " + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index cf34f7c6daf0b..ca0cd90d94fa7 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -157,7 +157,7 @@ select from_json() struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 18
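For reference, a minimal end-to-end sketch of the API this series lands — it assumes a build with the patches applied and a local SparkSession; the sample record and session setup are illustrative only, not part of the patches:

    // Infer a DDL schema from a sample record, then reuse it to parse a column.
    import scala.collection.JavaConverters._

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{from_csv, lit, schema_of_csv}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq("0.1,1,a").toDF("csv")

    // Prints struct<_c0:double,_c1:int,_c2:string>; the argument must be a
    // literal, as the new checkInputDataTypes in SchemaOfCsv enforces.
    df.select(schema_of_csv("0.1,1,a")).show(false)

    // from_csv accepts a DDL string literal or schema_of_csv over a literal,
    // mirroring the csv-functions.sql tests above.
    df.select(from_csv($"csv", schema_of_csv(lit("0.1,1,a")), Map.empty[String, String].asJava))
      .show(false)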