diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 54af314fe417..5140db90c595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -165,10 +165,14 @@ case class SchemaOfCsv( @transient private lazy val csv = child.eval().asInstanceOf[UTF8String] - override def checkInputDataTypes(): TypeCheckResult = child match { - case Literal(s, StringType) if s != null => super.checkInputDataTypes() - case _ => TypeCheckResult.TypeCheckFailure( - s"The input csv should be a string literal and not null; however, got ${child.sql}.") + override def checkInputDataTypes(): TypeCheckResult = { + if (child.foldable && csv != null) { + super.checkInputDataTypes() + } else { + TypeCheckResult.TypeCheckFailure( + "The input csv should be a foldable string expression and not null; " + + s"however, got ${child.sql}.") + } } override def eval(v: InternalRow): Any = { diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 8495bef9122e..2ce069a62ac7 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -91,7 +91,7 @@ select schema_of_csv(null) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7 +cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a foldable string expression and not null; however, got NULL.; line 1 pos 7 -- !query @@ -108,7 +108,7 @@ SELECT schema_of_csv(csvField) FROM csvTable struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7 +cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a foldable string expression and not null; however, got csvtable.`csvField`.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 61f0e138cc35..aa07815e492e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -200,4 +200,11 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { assert(readback(0).getAs[Row](0).getAs[Date](0).getTime >= 0) } } + + test("schema_of_csv - infers the schema of foldable CSV string") { + val input = concat_ws(",", lit(0.1), lit(1)) + checkAnswer( + spark.range(1).select(schema_of_csv(input)), + Seq(Row("struct<_c0:double,_c1:int>"))) + } }