
Commit 369a148

HyukjinKwon authored and brkyvz committed
[SPARK-19595][SQL] Support json array in from_json
## What changes were proposed in this pull request?

This PR proposes two changes:

**Do not allow JSON arrays with multiple elements, and return `null` instead, in `from_json` with `StructType` as the schema.**

Currently, only a single row is read when the input is a JSON array. So, the code below:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val schema = StructType(StructField("a", IntegerType) :: Nil)
Seq("""[{"a": 1}, {"a": 2}]""").toDF("struct").select(from_json(col("struct"), schema)).show()
```

prints

```
+--------------------+
|jsontostruct(struct)|
+--------------------+
|                 [1]|
+--------------------+
```

This PR suggests returning `null` instead when the schema is a `StructType` and the input is a JSON array with multiple elements:

```
+--------------------+
|jsontostruct(struct)|
+--------------------+
|                null|
+--------------------+
```

**Support JSON arrays in `from_json` with `ArrayType` as the schema.**

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
Seq("""[{"a": 1}, {"a": 2}]""").toDF("array").select(from_json(col("array"), schema)).show()
```

prints

```
+-------------------+
|jsontostruct(array)|
+-------------------+
|         [[1], [2]]|
+-------------------+
```

## How was this patch tested?

Unit tests in `JsonExpressionsSuite` and `JsonFunctionsSuite`, Python doctests, and manual tests.

Author: hyukjinkwon <[email protected]>

Closes #16929 from HyukjinKwon/disallow-array.
1 parent 80d5338 commit 369a148

File tree

5 files changed (+186 −17 lines)


python/pyspark/sql/functions.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -1773,11 +1773,11 @@ def json_tuple(col, *fields):
 @since(2.1)
 def from_json(col, schema, options={}):
     """
-    Parses a column containing a JSON string into a [[StructType]] with the
-    specified schema. Returns `null`, in the case of an unparseable string.
+    Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]]
+    with the specified schema. Returns `null`, in the case of an unparseable string.
 
     :param col: string column in json format
-    :param schema: a StructType to use when parsing the json column
+    :param schema: a StructType or ArrayType to use when parsing the json column
     :param options: options to control parsing. accepts the same options as the json datasource
 
     >>> from pyspark.sql.types import *
@@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_json(df.value, schema).alias("json")).collect()
     [Row(json=Row(a=1))]
+    >>> data = [(1, '''[{"a": 1}]''')]
+    >>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_json(df.value, schema).alias("json")).collect()
+    [Row(json=[Row(a=1)])]
     """
 
     sc = SparkContext._active_spark_context
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 50 additions & 7 deletions
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.ParseModes
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -480,23 +480,45 @@ case class JsonTuple(children: Seq[Expression])
 }
 
 /**
- * Converts an json input string to a [[StructType]] with the specified schema.
+ * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema.
  */
 case class JsonToStruct(
-    schema: StructType,
+    schema: DataType,
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
   extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
   override def nullable: Boolean = true
 
-  def this(schema: StructType, options: Map[String, String], child: Expression) =
+  def this(schema: DataType, options: Map[String, String], child: Expression) =
     this(schema, options, child, None)
 
+  override def checkInputDataTypes(): TypeCheckResult = schema match {
+    case _: StructType | ArrayType(_: StructType, _) =>
+      super.checkInputDataTypes()
+    case _ => TypeCheckResult.TypeCheckFailure(
+      s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
+  }
+
+  @transient
+  lazy val rowSchema = schema match {
+    case st: StructType => st
+    case ArrayType(st: StructType, _) => st
+  }
+
+  // This converts parsed rows to the desired output by the given schema.
+  @transient
+  lazy val converter = schema match {
+    case _: StructType =>
+      (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
+    case ArrayType(_: StructType, _) =>
+      (rows: Seq[InternalRow]) => new GenericArrayData(rows)
+  }
+
   @transient
   lazy val parser =
     new JacksonParser(
-      schema,
+      rowSchema,
       new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
 
   override def dataType: DataType = schema
@@ -505,11 +527,32 @@ case class JsonToStruct(
     copy(timeZoneId = Option(timeZoneId))
 
   override def nullSafeEval(json: Any): Any = {
+    // When input is,
+    // - `null`: `null`.
+    // - invalid json: `null`.
+    // - empty string: `null`.
+    //
+    // When the schema is array,
+    // - json array: `Array(Row(...), ...)`
+    // - json object: `Array(Row(...))`
+    // - empty json array: `Array()`.
+    // - empty json object: `Array(Row(null))`.
+    //
+    // When the schema is a struct,
+    // - json object/array with single element: `Row(...)`
+    // - json array with multiple elements: `null`
+    // - empty json array: `null`.
+    // - empty json object: `Row(null)`.
 
+    // We need `null` if the input string is an empty string. `JacksonParser` can
+    // deal with this but produces `Nil`.
+    if (json.toString.trim.isEmpty) return null
+
     try {
-      parser.parse(
+      converter(parser.parse(
         json.asInstanceOf[UTF8String],
         CreateJacksonParser.utf8String,
-        identity[UTF8String]).headOption.orNull
+        identity[UTF8String]))
     } catch {
       case _: SparkSQLJsonProcessingException => null
     }
```
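The conversion rule at the heart of this change is easy to miss inside the diff, so here is a minimal standalone sketch of it (plain Scala, no Spark internals; `JsonSchema` and its two cases are illustrative stand-ins for the real `DataType`s, and `Map` stands in for `InternalRow`):

```scala
sealed trait JsonSchema
case object StructSchema extends JsonSchema          // stands in for StructType
case object ArrayOfStructsSchema extends JsonSchema  // stands in for ArrayType(StructType)

// Mirrors the new `converter`: a struct schema accepts exactly one parsed row,
// while an array schema wraps however many rows the parser produced.
def convert(schema: JsonSchema, rows: Seq[Map[String, Any]]): Any = schema match {
  case StructSchema if rows.length == 1 => rows.head
  case StructSchema                     => null  // zero or multiple rows -> null
  case ArrayOfStructsSchema             => rows  // any number of rows, even zero
}

// convert(StructSchema, Seq(Map("a" -> 1)))                => Map(a -> 1)
// convert(StructSchema, Seq(Map("a" -> 1), Map("a" -> 2))) => null (previously: head only)
// convert(ArrayOfStructsSchema, Nil)                       => List() (empty JSON array)
```

This is also why `rowSchema` extracts the element struct in both cases: the parser always produces a `Seq[InternalRow]`, and only the final conversion step differs by schema.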

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 57 additions & 1 deletion
```diff
@@ -22,7 +22,7 @@ import java.util.Calendar
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     )
   }
 
+  test("from_json - input=array, schema=array, output=array") {
+    val input = """[{"a": 1}, {"a": 2}]"""
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(1) :: InternalRow(2) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=object, schema=array, output=array of single row") {
+    val input = """{"a": 1}"""
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(1) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty array, schema=array, output=empty array") {
+    val input = "[ ]"
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty object, schema=array, output=array of single row with null") {
+    val input = "{ }"
+    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val output = InternalRow(null) :: Nil
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=array of single object, schema=struct, output=single row") {
+    val input = """[{"a": 1}]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = InternalRow(1)
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=array, schema=struct, output=null") {
+    val input = """[{"a": 1}, {"a": 2}]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = null
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty array, schema=struct, output=null") {
+    val input = """[]"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = null
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
+  test("from_json - input=empty object, schema=struct, output=single row with null") {
+    val input = """{ }"""
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val output = InternalRow(null)
+    checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+  }
+
   test("from_json null input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
```

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 47 additions & 5 deletions
```diff
@@ -2973,7 +2973,22 @@ object functions {
    * @group collection_funcs
    * @since 2.1.0
    */
-  def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+  def from_json(e: Column, schema: StructType, options: Map[String, String]): Column =
+    from_json(e, schema.asInstanceOf[DataType], options)
+
+  /**
+   * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options and the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
     JsonToStruct(schema, options, e.expr)
   }
 
@@ -2992,6 +3007,21 @@ object functions {
   def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
     from_json(e, schema, options.asScala.toMap)
 
+  /**
+   * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options and the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
+    from_json(e, schema, options.asScala.toMap)
+
   /**
    * Parses a column containing a JSON string into a `StructType` with the specified schema.
    * Returns `null`, in the case of an unparseable string.
@@ -3006,8 +3036,21 @@ object functions {
     from_json(e, schema, Map.empty[String, String])
 
   /**
-   * Parses a column containing a JSON string into a `StructType` with the specified schema.
-   * Returns `null`, in the case of an unparseable string.
+   * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   *
+   * @group collection_funcs
+   * @since 2.2.0
+   */
+  def from_json(e: Column, schema: DataType): Column =
+    from_json(e, schema, Map.empty[String, String])
+
+  /**
+   * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+   * with the specified schema. Returns `null`, in the case of an unparseable string.
    *
    * @param e a string column containing JSON data.
    * @param schema the schema to use when parsing the json string as a json string
@@ -3016,8 +3059,7 @@ object functions {
    * @since 2.1.0
    */
   def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
-    from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options)
-
+    from_json(e, DataType.fromJson(schema), options)
 
   /**
    * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the
```
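For orientation, a hedged usage sketch of the new `DataType`-typed overloads added above (assumes a live `SparkSession` named `spark`; the column name and output layout are indicative of what `show()` printed around this version):

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._  // assumes a SparkSession named `spark`

val df = Seq("""[{"a": 1}, {"a": 2}]""").toDF("value")

// The DataType-typed overload accepts an ArrayType schema directly, so a
// JSON array now parses into an array of structs per row.
val arraySchema: DataType = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
df.select(from_json($"value", arraySchema)).show()
// +-------------------+
// |jsontostruct(value)|
// +-------------------+
// |         [[1], [2]]|
// +-------------------+
```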

sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala

Lines changed: 24 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.functions.{from_json, struct, to_json}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType}
+import org.apache.spark.sql.types._
 
 class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(null) :: Nil)
   }
 
+  test("from_json invalid schema") {
+    val df = Seq("""{"a" 1}""").toDS()
+    val schema = ArrayType(StringType)
+    val message = intercept[AnalysisException] {
+      df.select(from_json($"value", schema))
+    }.getMessage
+
+    assert(message.contains(
+      "Input schema array<string> must be a struct or an array of structs."))
+  }
+
+  test("from_json array support") {
+    val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS()
+    val schema = ArrayType(
+      StructType(
+        StructField("a", IntegerType) ::
+        StructField("b", StringType) :: Nil))
+
+    checkAnswer(
+      df.select(from_json($"value", schema)),
+      Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
+  }
+
   test("to_json") {
     val df = Seq(Tuple1(Tuple1(1))).toDF("a")
```
