diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 426a4a8c93a6..376b86ea69bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1773,11 +1773,11 @@ def json_tuple(col, *fields):
 @since(2.1)
 def from_json(col, schema, options={}):
     """
-    Parses a column containing a JSON string into a [[StructType]] with the
-    specified schema. Returns `null`, in the case of an unparseable string.
+    Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]]
+    with the specified schema. Returns `null`, in the case of an unparseable string.
 
     :param col: string column in json format
-    :param schema: a StructType to use when parsing the json column
+    :param schema: a StructType or ArrayType to use when parsing the json column
     :param options: options to control parsing. accepts the same options as the json datasource
 
     >>> from pyspark.sql.types import *
@@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_json(df.value, schema).alias("json")).collect()
     [Row(json=Row(a=1))]
+    >>> data = [(1, '''[{"a": 1}]''')]
+    >>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_json(df.value, schema).alias("json")).collect()
+    [Row(json=[Row(a=1)])]
     """
 
     sc = SparkContext._active_spark_context
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 1e690a446951..dbff62efdddb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.ParseModes
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -480,23 +480,45 @@ case class JsonTuple(children: Seq[Expression])
 }
 
 /**
- * Converts an json input string to a [[StructType]] with the specified schema.
+ * Converts a JSON input string to a [[StructType]] or [[ArrayType]] with the specified schema.
  */
 case class JsonToStruct(
-    schema: StructType,
+    schema: DataType,
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
   extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
   override def nullable: Boolean = true
 
-  def this(schema: StructType, options: Map[String, String], child: Expression) =
+  def this(schema: DataType, options: Map[String, String], child: Expression) =
     this(schema, options, child, None)
 
+  override def checkInputDataTypes(): TypeCheckResult = schema match {
+    case _: StructType | ArrayType(_: StructType, _) =>
+      super.checkInputDataTypes()
+    case _ => TypeCheckResult.TypeCheckFailure(
+      s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
+  }
+
+  @transient
+  lazy val rowSchema = schema match {
+    case st: StructType => st
+    case ArrayType(st: StructType, _) => st
+  }
+
+  // This converts parsed rows to the desired output by the given schema.
+  @transient
+  lazy val converter = schema match {
+    case _: StructType =>
+      (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
+    case ArrayType(_: StructType, _) =>
+      (rows: Seq[InternalRow]) => new GenericArrayData(rows)
+  }
+
   @transient
   lazy val parser =
     new JacksonParser(
-      schema,
+      rowSchema,
       new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
 
   override def dataType: DataType = schema
@@ -505,11 +527,32 @@ case class JsonToStruct(
     copy(timeZoneId = Option(timeZoneId))
 
   override def nullSafeEval(json: Any): Any = {
+    // When input is,
+    //  - `null`: `null`.
+    //  - invalid json: `null`.
+    //  - empty string: `null`.
+    //
+    // When the schema is array,
+    //  - json array: `Array(Row(...), ...)`
+    //  - json object: `Array(Row(...))`
+    //  - empty json array: `Array()`.
+    //  - empty json object: `Array(Row(null))`.
+    //
+    // When the schema is a struct,
+    //  - json object/array with single element: `Row(...)`
+    //  - json array with multiple elements: `null`
+    //  - empty json array: `null`.
+    //  - empty json object: `Row(null)`.
+
+    // We need `null` if the input string is an empty string. `JacksonParser` can
+    // deal with this but produces `Nil`.
+    if (json.toString.trim.isEmpty) return null
+
     try {
-      parser.parse(
+      converter(parser.parse(
         json.asInstanceOf[UTF8String],
         CreateJacksonParser.utf8String,
-        identity[UTF8String]).headOption.orNull
+        identity[UTF8String]))
     } catch {
       case _: SparkSQLJsonProcessingException => null
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 0c46819cdb9c..e3584909ddc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -22,7 +22,7 @@ import java.util.Calendar
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     )
   }
 
+  test("from_json - input=array, schema=array, output=array") {
+    val input = """[{"a": 1}, {"a": 2}]"""
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(1) :: InternalRow(2) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=object, schema=array, output=array of single row") {
+    val input = """{"a": 1}"""
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(1) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty array, schema=array, output=empty array") {
+    val input = "[ ]"
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty object, schema=array, output=array of single row with null") {
+    val input = "{ }"
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(null) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=array of single object, schema=struct, output=single row") {
+    val input = """[{"a": 1}]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = InternalRow(1)
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=array, schema=struct, output=null") {
+    val input = """[{"a": 1}, {"a": 2}]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = null
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty array, schema=struct, output=null") {
+    val input = """[]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = null
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty object, schema=struct, output=single row with null") {
+    val input = """{ }"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = InternalRow(null)
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
   test("from_json null input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 24ed906d3368..0e048aa4c05b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2964,7 +2964,22 @@ object functions {
    * @group collection_funcs
    * @since 2.1.0
    */
-  def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+  def from_json(e: Column, schema: StructType, options: Map[String, String]): Column =
+    from_json(e, schema.asInstanceOf[DataType], options)
+
+  /**
+   * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options as the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
     JsonToStruct(schema, options, e.expr)
   }
 
@@ -2983,6 +2998,21 @@
   def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
     from_json(e, schema, options.asScala.toMap)
 
+  /**
+   * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options as the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
+    from_json(e, schema, options.asScala.toMap)
+
   /**
    * Parses a column containing a JSON string into a `StructType` with the specified schema.
    * Returns `null`, in the case of an unparseable string.
@@ -2997,8 +3027,21 @@
     from_json(e, schema, Map.empty[String, String])
 
   /**
-   * Parses a column containing a JSON string into a `StructType` with the specified schema.
-   * Returns `null`, in the case of an unparseable string.
+   * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType): Column =
+    from_json(e, schema, Map.empty[String, String])
+
+  /**
+   * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
    *
    * @param e a string column containing JSON data.
    * @param schema the schema to use when parsing the json string as a json string
@@ -3007,8 +3050,7 @@
    * @since 2.1.0
    */
   def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
-    from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options)
-
+    from_json(e, DataType.fromJson(schema), options)
 
   /**
    * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 9c39b3c7f09b..953d161ec2a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.functions.{from_json, struct, to_json}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType}
+import org.apache.spark.sql.types._
 
 class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(null) :: Nil)
   }
 
+  test("from_json invalid schema") {
+    val df = Seq("""{"a" 1}""").toDS()
+    val schema = ArrayType(StringType)
+    val message = intercept[AnalysisException] {
+      df.select(from_json($"value", schema))
+    }.getMessage
+
+    assert(message.contains(
+      "Input schema array<string> must be a struct or an array of structs."))
+  }
+
+  test("from_json array support") {
+    val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS()
+    val schema = ArrayType(
+      StructType(
+        StructField("a", IntegerType) ::
+        StructField("b", StringType) :: Nil))
+
+    checkAnswer(
+      df.select(from_json($"value", schema)),
+      Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
+  }
+
   test("to_json") {
     val df = Seq(Tuple1(Tuple1(1))).toDF("a")
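
For reference, a rough usage sketch of the array support this patch adds (not part of the diff itself); it assumes a running `SparkSession` in scope as `spark` and mirrors the `JsonFunctionsSuite` case above:

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._
import spark.implicits._

// A Dataset[String] gets the default column name "value".
val df = Seq("""[{"a": 1}, {"a": 2}]""").toDS()

// With an ArrayType schema, each JSON array parses into an array of structs;
// an unparseable string yields null, and a single JSON object yields a
// one-element array.
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
df.select(from_json($"value", schema).alias("json")).collect()
// roughly: Array(Row(Seq(Row(1), Row(2))))
```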