@@ -484,7 +484,7 @@ case class JsonTuple(children: Seq[Expression])
* Converts a JSON input string to a [[StructType]] with the specified schema.
*/
case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression)
extends Expression with CodegenFallback with ExpectsInputTypes {
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

@transient
@@ -495,11 +495,8 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))

override def dataType: DataType = schema
override def children: Seq[Expression] = child :: Nil

override def eval(input: InternalRow): Any = {
val json = child.eval(input)
if (json == null) return null
override def nullSafeEval(json: Any): Any = {
Contributor

Just to be clear: the old code already returned null when the input was null, right?

Member Author

Yes, null safety for from_json was fixed in 6e27018; I just decided to remove a little code duplication while matching to_json to from_json.

try parser.parse(json.toString).head catch {
case _: SparkSQLJsonProcessingException => null
}
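
Note on the change above: the explicit `if (json == null) return null` can be dropped because UnaryExpression.eval performs that check before delegating to nullSafeEval. Roughly, the relevant part of Catalyst's UnaryExpression looks like the sketch below (simplified; the real base class also provides codegen helpers, and the class name here is only illustrative):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression

// Illustrative sketch of the null handling UnaryExpression provides.
abstract class SketchUnaryExpression extends Expression {
  def child: Expression

  // A unary expression has exactly one child, which is why the explicit
  // `children` override in JsonToStruct/StructToJson also goes away.
  override final def children: Seq[Expression] = child :: Nil

  // eval checks for a null child value once; nullSafeEval is only
  // ever called with a non-null argument.
  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) null else nullSafeEval(value)
  }

  // Subclasses such as JsonToStruct and StructToJson override this.
  protected def nullSafeEval(input: Any): Any =
    sys.error("UnaryExpressions must override either eval or nullSafeEval")
}

This is also why both expressions keep nullable = true: a null input simply yields a null output without ever reaching the JSON parser or generator.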
@@ -512,7 +509,7 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
* Converts a [[StructType]] to a json output string.
*/
case class StructToJson(options: Map[String, String], child: Expression)
extends Expression with CodegenFallback with ExpectsInputTypes {
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

@transient
Expand All @@ -523,7 +520,6 @@ case class StructToJson(options: Map[String, String], child: Expression)
new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer)

override def dataType: DataType = StringType
override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
if (StructType.acceptsType(child.dataType)) {
@@ -540,8 +536,8 @@ case class StructToJson(options: Map[String, String], child: Expression)
}
}

override def eval(input: InternalRow): Any = {
gen.write(child.eval(input).asInstanceOf[InternalRow])
override def nullSafeEval(row: Any): Any = {
gen.write(row.asInstanceOf[InternalRow])
gen.flush()
val json = writer.toString
writer.reset()
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -347,7 +347,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("from_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
JsonToStruct(schema, Map.empty, Literal(null)),
JsonToStruct(schema, Map.empty, Literal.create(null, StringType)),
null
)
}
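
A side note on the literal change above: Literal(null) infers its type from the value, so a bare null becomes a NullType literal, whereas Literal.create(null, StringType) is a null carrying an explicit string type, which is what a null value in a JSON string column looks like to from_json. A small sketch of the difference (hedged against Catalyst's Literal factories; exact behavior may vary by Spark version):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{NullType, StringType}

val untyped = Literal(null)                       // type inferred from the value: NullType
val typedNull = Literal.create(null, StringType)  // null value with an explicit StringType

assert(untyped.dataType == NullType)
assert(typedNull.dataType == StringType && typedNull.value == null)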
@@ -360,4 +360,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"""{"a":1}"""
)
}

test("to_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(null, schema)
checkEvaluation(
StructToJson(Map.empty, struct),
null
)
}
}
@@ -141,4 +141,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
assert(e.getMessage.contains(
"Unable to convert column a of type calendarinterval to JSON."))
}

test("roundtrip in to_json and from_json") {
val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
val readBackOne = dfOne.select(to_json($"struct").as("json"))
.select(from_json($"json", schemaOne).as("struct"))
checkAnswer(dfOne, readBackOne)

val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("json")
val schemaTwo = new StructType().add("a", IntegerType)
val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("struct"))
.select(to_json($"struct").as("json"))
checkAnswer(dfTwo, readBackTwo)
}
}
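
For reference, a spark-shell style sketch of what the round trip in this test looks like through the public DataFrame API (the local master and app name are placeholders; to_json and from_json live in org.apache.spark.sql.functions):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_json, to_json}
import org.apache.spark.sql.types.{IntegerType, StructType}

val spark = SparkSession.builder().master("local[*]").appName("json-null-roundtrip").getOrCreate()
import spark.implicits._

val schema = new StructType().add("a", IntegerType)

// One well-formed JSON row and one null row: both should survive the round trip,
// with the null row staying null rather than failing inside to_json.
Seq(Some("""{"a":1}"""), None).toDF("json")
  .select(from_json($"json", schema).as("struct"))
  .select(to_json($"struct").as("json"))
  .show()

spark.stop()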