diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ca2a256983d67..beb1a065d2803 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}): return Column(jc) +@ignore_unicode_prefix +@since(3.0) +def schema_of_csv(csv, options={}): + """ + Parses a CSV string and infers its schema in DDL format. + + :param csv: a CSV string or a string literal containing a CSV string. + :param options: options to control parsing. accepts the same options as the CSV datasource + + >>> df = spark.range(1) + >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + """ + if isinstance(csv, basestring): + col = _create_column_from_literal(csv) + elif isinstance(csv, Column): + col = _to_java_column(csv) + else: + raise TypeError("csv argument should be a column or string") + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_csv(col, options) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}): :param schema: a string with schema in DDL format to use when parsing the CSV column. :param options: options to control parsing. accepts the same options as the CSV datasource - >>> data = [(1, '1')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() - [Row(csv=Row(a=1))] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() - [Row(csv=Row(a=1))] + >>> data = [("1,2,3",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() + [Row(csv=Row(a=1, b=2, c=3))] + >>> value = data[0][0] + >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() + [Row(csv=Row(_c0=1, _c1=2, _c2=3))] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..9ebb44d14d678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -523,7 +523,8 @@ object FunctionRegistry { castAlias("string", StringType), // csv - expression[CsvToStructs]("from_csv") + expression[CsvToStructs]("from_csv"), + expression[SchemaOfCsv]("schema_of_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 4326a186d6d5f..799e9994451b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -15,19 +15,18 @@ * limitations under the License.
*/ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal -import scala.util.control.Exception._ +import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[csv] object CSVInferSchema { +object CSVInferSchema { /** * Similar to the JSON schema inference @@ -44,13 +43,7 @@ private[csv] object CSVInferSchema { val rootTypes: Array[DataType] = tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) - header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other - } - StructField(thisHeader, dType, nullable = true) - } + toStructFields(rootTypes, header, options) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -59,7 +52,20 @@ private[csv] object CSVInferSchema { StructType(fields) } - private def inferRowType(options: CSVOptions) + def toStructFields( + fieldTypes: Array[DataType], + header: Array[String], + options: CSVOptions): Array[StructField] = { + header.zip(fieldTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) + } + } + + def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index e5708894f22b4..040b56cc1caea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{MapType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String object ExprUtils { - def evalSchemaExpr(exp: Expression): StructType = exp match { - case Literal(s, StringType) => StructType.fromDDL(s.toString) + def evalSchemaExpr(exp: Expression): StructType = { + // Use `DataType.fromDDL` since the type string can be struct<...>. 
+ val dataType = exp match { + case Literal(s, StringType) => + DataType.fromDDL(s.toString) + case e @ SchemaOfCsv(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_csv function instead of ${e.sql}") + } + + if (!dataType.isInstanceOf[StructType]) { + throw new AnalysisException( + s"Schema should be struct type but got ${dataType.sql}.") + } + dataType.asInstanceOf[StructType] + } + + def evalTypeExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( - s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 853b1ea6a5f1c..e70296fe31292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import com.univocity.parsers.csv.CsvParser + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.csv._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util._ @@ -120,3 +123,54 @@ case class CsvToStructs( override def prettyName: String = "from_csv" } + +/** + * A function that infers the schema of a CSV string.
*/ +@ExpressionDescription( + usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of a CSV string.", + examples = """ + Examples: + > SELECT _FUNC_('1,abc'); + struct<_c0:int,_c1:string> + """, + since = "3.0.0") +case class SchemaOfCsv( + child: Expression, + options: Map[String, String]) + extends UnaryExpression with CodegenFallback { + + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = ExprUtils.convertToMapData(options)) + + override def dataType: DataType = StringType + + override def nullable: Boolean = false + + @transient + private lazy val csv = child.eval().asInstanceOf[UTF8String] + + override def checkInputDataTypes(): TypeCheckResult = child match { + case Literal(s, StringType) if s != null => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"The input csv should be a string literal and not null; however, got ${child.sql}.") + } + + override def eval(v: InternalRow): Any = { + val parsedOptions = new CSVOptions(options, true, "UTC") + val parser = new CsvParser(parsedOptions.asParserSettings) + val row = parser.parseLine(csv.toString) + assert(row != null, "Parsed CSV record should not be null.") + + val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) + val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + UTF8String.fromString(st.catalogString) + } + + override def prettyName: String = "schema_of_csv" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 77af5906010f3..eafcb6161036e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -529,7 +529,7 @@ case class JsonToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = options, child = child, timeZoneId = None) @@ -538,7 +538,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -784,15 +784,3 @@ case class SchemaOfJson( override def prettyName: String = "schema_of_json" } - -object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { - case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] - DataType.fromDDL(ddlSchema.toString) - case e => throw new AnalysisException( - "Schema should be specified in DDL format as a string literal" + - s" or output of the schema_of_json function instead of ${e.sql}") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala similarity index 98% rename from
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 6b64f2ffa98dd..651846d2ebcb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 6f231142949d1..e4e7dc2e8c0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 65987af710750..386e0d133dff6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P }.getCause assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) } + + test("infer schema of CSV strings") { + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + } + + test("infer schema of CSV strings by using options") { + checkEvaluation( + new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), + "struct<_c0:int,_c1:string>") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 9e7b45db9f280..4808e8ef042d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 757a3226855c5..5c2abadca5424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3896,6 +3896,41 @@ object functions { withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) } + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv)) + + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a string literal containing a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr)) + + /** + * Parses a CSV string and infers its schema in DDL format using options. + * + * @param csv a string literal containing a CSV string. + * @param options options to control how the CSV is parsed. accepts the same options as the + * CSV data source. See [[DataFrameReader#csv]]. + * @return a column with a string literal containing the schema in DDL format. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index d2214fd016028..5be6f807931b8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -7,3 +7,11 @@ select from_csv('1', 'a InvalidType'); select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_csv('1', 'a INT', map('mode', 1)); select from_csv(); +-- infer schema of csv literal +select from_csv('1,abc', schema_of_csv('1,abc')); +select schema_of_csv('1|abc', map('delimiter', '|')); +select schema_of_csv(null); +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a'); +SELECT schema_of_csv(csvField) FROM csvTable; +-- Clean up +DROP VIEW IF EXISTS csvTable; diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index f19f34a773c16..677bbd97c549d 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 13 -- !query 0 @@ -24,7 +24,7 @@ select from_csv('1', 1) struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7 --
!query 3 @@ -67,3 +67,53 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 + + +-- !query 7 +select from_csv('1,abc', schema_of_csv('1,abc')) +-- !query 7 schema +struct<from_csv(1,abc):struct<_c0:int,_c1:string>> +-- !query 7 output +{"_c0":1,"_c1":"abc"} + + +-- !query 8 +select schema_of_csv('1|abc', map('delimiter', '|')) +-- !query 8 schema +struct<schema_of_csv(1|abc):string> +-- !query 8 output +struct<_c0:int,_c1:string> + + +-- !query 9 +select schema_of_csv(null) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7 + + +-- !query 10 +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a') +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +SELECT schema_of_csv(csvField) FROM csvTable +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7 + + +-- !query 12 +DROP VIEW IF EXISTS csvTable +-- !query 12 schema +struct<> +-- !query 12 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 38a2143d6d0f0..9395f050b41ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -59,4 +59,19 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) } + + test("schema_of_csv - infers schemas") { + checkAnswer( + spark.range(1).select(schema_of_csv(lit("0.1,1"))), + Seq(Row("struct<_c0:double,_c1:int>"))) + checkAnswer( + spark.range(1).select(schema_of_csv("0.1,1")), + Seq(Row("struct<_c0:double,_c1:int>"))) + } + + test("schema_of_csv - infers schemas using options") { + val df = spark.range(1) + .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) + checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) + }
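For reference, a minimal end-to-end sketch of how the Scala API introduced by this patch composes: `schema_of_csv` infers a DDL schema string from a sample record, and the `from_csv` overload that takes a Column schema can consume it directly. The `SchemaOfCsvDemo` object, the local SparkSession setup, and the sample data are illustrative assumptions for this sketch, not part of the diff.

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_csv, lit, schema_of_csv}

// Illustrative driver, not part of the patch.
object SchemaOfCsvDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("schema_of_csv demo").getOrCreate()
    import spark.implicits._

    val df = Seq("0.1,1,a").toDF("csvField")

    // Infer a DDL schema string from a sample record; given the inference
    // rules above, this should print struct<_c0:double,_c1:int,_c2:string>.
    df.select(schema_of_csv("0.1,1,a")).show(truncate = false)

    // Feed the inferred schema straight into from_csv. Note the argument to
    // schema_of_csv must be a string literal, as enforced by
    // SchemaOfCsv.checkInputDataTypes above.
    df.select(from_csv($"csvField", schema_of_csv(lit("0.1,1,a")), Map.empty[String, String].asJava).as("parsed"))
      .show(truncate = false)

    spark.stop()
  }
}

This mirrors the new PySpark doctests: the inferred schema uses the positional `_c0`, `_c1`, ... field names, so downstream code should either rename the parsed columns or pass an explicit DDL schema when stable names matter.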