Skip to content

Commit 3eda057

Browse files
HyukjinKwonmarmbrus
authored andcommitted
[SPARK-18295][SQL] Make to_json function null safe (matching it to from_json)
## What changes were proposed in this pull request? This PR proposes to match the behaviour of the `to_json` function to that of `from_json` for null-safety. Currently, it throws `NullPointerException`, but this PR fixes this to produce `null` instead. With the data below: ```scala import spark.implicits._ val df = Seq(Some(Tuple1(Tuple1(1))), None).toDF("a") df.show() ``` ``` +----+ | a| +----+ | [1]| |null| +----+ ``` the code below ```scala import org.apache.spark.sql.functions._ df.select(to_json($"a")).show() ``` produces the following. **Before** throws `NullPointerException` as below: ``` java.lang.NullPointerException at org.apache.spark.sql.catalyst.json.JacksonGenerator.org$apache$spark$sql$catalyst$json$JacksonGenerator$$writeFields(JacksonGenerator.scala:138) at org.apache.spark.sql.catalyst.json.JacksonGenerator$$anonfun$write$1.apply$mcV$sp(JacksonGenerator.scala:194) at org.apache.spark.sql.catalyst.json.JacksonGenerator.org$apache$spark$sql$catalyst$json$JacksonGenerator$$writeObject(JacksonGenerator.scala:131) at org.apache.spark.sql.catalyst.json.JacksonGenerator.write(JacksonGenerator.scala:193) at org.apache.spark.sql.catalyst.expressions.StructToJson.eval(jsonExpressions.scala:544) at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:142) at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:48) at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:30) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) ``` **After** ``` +---------------+ |structtojson(a)| +---------------+ | {"_1":1}| | null| +---------------+ ``` ## How was this patch tested? Unit tests in `JsonExpressionsSuite.scala` and `JsonFunctionsSuite.scala`. Author: hyukjinkwon <[email protected]> Closes #15792 from HyukjinKwon/SPARK-18295.
1 parent 3a710b9 commit 3eda057

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ case class JsonTuple(children: Seq[Expression])
484484
* Converts a JSON input string to a [[StructType]] with the specified schema.
485485
*/
486486
case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression)
487-
extends Expression with CodegenFallback with ExpectsInputTypes {
487+
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
488488
override def nullable: Boolean = true
489489

490490
@transient
@@ -495,11 +495,8 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
495495
new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))
496496

497497
override def dataType: DataType = schema
498-
override def children: Seq[Expression] = child :: Nil
499498

500-
override def eval(input: InternalRow): Any = {
501-
val json = child.eval(input)
502-
if (json == null) return null
499+
override def nullSafeEval(json: Any): Any = {
503500
try parser.parse(json.toString).head catch {
504501
case _: SparkSQLJsonProcessingException => null
505502
}
@@ -512,7 +509,7 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
512509
* Converts a [[StructType]] to a json output string.
513510
*/
514511
case class StructToJson(options: Map[String, String], child: Expression)
515-
extends Expression with CodegenFallback with ExpectsInputTypes {
512+
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
516513
override def nullable: Boolean = true
517514

518515
@transient
@@ -523,7 +520,6 @@ case class StructToJson(options: Map[String, String], child: Expression)
523520
new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer)
524521

525522
override def dataType: DataType = StringType
526-
override def children: Seq[Expression] = child :: Nil
527523

528524
override def checkInputDataTypes(): TypeCheckResult = {
529525
if (StructType.acceptsType(child.dataType)) {
@@ -540,8 +536,8 @@ case class StructToJson(options: Map[String, String], child: Expression)
540536
}
541537
}
542538

543-
override def eval(input: InternalRow): Any = {
544-
gen.write(child.eval(input).asInstanceOf[InternalRow])
539+
override def nullSafeEval(row: Any): Any = {
540+
gen.write(row.asInstanceOf[InternalRow])
545541
gen.flush()
546542
val json = writer.toString
547543
writer.reset()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.util.ParseModes
23-
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
23+
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
2424
import org.apache.spark.unsafe.types.UTF8String
2525

2626
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -347,7 +347,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
347347
test("from_json null input column") {
348348
val schema = StructType(StructField("a", IntegerType) :: Nil)
349349
checkEvaluation(
350-
JsonToStruct(schema, Map.empty, Literal(null)),
350+
JsonToStruct(schema, Map.empty, Literal.create(null, StringType)),
351351
null
352352
)
353353
}
@@ -360,4 +360,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
360360
"""{"a":1}"""
361361
)
362362
}
363+
364+
test("to_json null input column") {
365+
val schema = StructType(StructField("a", IntegerType) :: Nil)
366+
val struct = Literal.create(null, schema)
367+
checkEvaluation(
368+
StructToJson(Map.empty, struct),
369+
null
370+
)
371+
}
363372
}

sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
141141
assert(e.getMessage.contains(
142142
"Unable to convert column a of type calendarinterval to JSON."))
143143
}
144+
145+
test("roundtrip in to_json and from_json") {
146+
val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
147+
val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
148+
val readBackOne = dfOne.select(to_json($"struct").as("json"))
149+
.select(from_json($"json", schemaOne).as("struct"))
150+
checkAnswer(dfOne, readBackOne)
151+
152+
val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("json")
153+
val schemaTwo = new StructType().add("a", IntegerType)
154+
val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("struct"))
155+
.select(to_json($"struct").as("json"))
156+
checkAnswer(dfTwo, readBackTwo)
157+
}
144158
}

0 commit comments

Comments
 (0)