Skip to content

Commit 57ed5a8

Browse files
viirya authored and dongjoon-hyun committed
[SPARK-33007][SQL] Simplify named_struct + get struct field + from_json expression chain
### What changes were proposed in this pull request?

This proposes to simplify the named_struct + get struct field + from_json expression chain from `struct(from_json.col1, from_json.col2, from_json.col3...)` to `struct(from_json)`.

### Why are the changes needed?

To simplify complex expression trees that could be produced by query optimization or written by the user.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #29942 from viirya/SPARK-33007.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 0b326d5 commit 57ed5a8

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprs.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,46 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
2828
* The optimization includes:
2929
* 1. JsonToStructs(StructsToJson(child)) => child.
3030
* 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
31+
* 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
32+
* If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
33+
* if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
34+
* contains all accessed fields in original CreateNamedStruct.
3135
*/
3236
object OptimizeJsonExprs extends Rule[LogicalPlan] {
3337
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
3438
case p => p.transformExpressions {
39+
40+
// Rewrites `named_struct(from_json(j).f1, from_json(j).f2, ...)` into
// `If(IsNull(j), <struct of typed nulls>, KnownNotNull(from_json(j)))`, where the new
// `from_json` uses the created struct's type as its schema.
case c: CreateNamedStruct
    // An empty struct has no `JsonToStructs` to reuse. `forall` is vacuously true on an empty
    // list, so without this check the body below would call `.head` on an empty list.
    if c.valExprs.nonEmpty && c.valExprs.forall {
      // Every value must be a field access on a `JsonToStructs` that is semantically equal
      // to the one accessed by the first value.
      case GetStructField(j: JsonToStructs, _, _) =>
        j.semanticEquals(c.valExprs.head.children.head)
      case _ => false
    } =>
  // The struct must expose each JSON field under its original name (no aliasing).
  val sameFieldName = c.names.zip(c.valExprs).forall {
    case (name, valExpr: GetStructField) =>
      name.toString == valExpr.childSchema(valExpr.ordinal).name
    case _ => false
  }

  // Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
  // `JsonToStructs` does not support parsing json with duplicated field names.
  val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length

  if (sameFieldName && !duplicateFields) {
    // The guard established that every value is a `GetStructField` over an equal
    // `JsonToStructs`, so the child of the first value stands in for all of them.
    val sharedJsonToStructs = c.valExprs.head.children.head.asInstanceOf[JsonToStructs]
    val fromJson = sharedJsonToStructs.copy(schema = c.dataType)

    // Same field names as `c`, with every value replaced by a null literal of the value's type.
    val nullFields = c.children.grouped(2).flatMap {
      case Seq(name, value) => Seq(name, Literal(null, value.dataType))
    }.toSeq

    If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
  } else {
    // Aliased or duplicated field names: leave the expression untouched.
    c
  }
70+
3571
case jsonToStructs @ JsonToStructs(_, options1,
3672
StructsToJson(options2, child, timeZoneId2), timeZoneId1)
3773
if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,71 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {
199199
JsonToStructs(prunedSchema2, options, 'json), field2, 0, 1, false).as("b")).analyze
200200
comparePlans(optimized2, expected2)
201201
}
202+
203+
// Covers the `named_struct` + `from_json` simplification: one case where the rewrite is
// expected, and three where the expected plan keeps one `from_json` per struct field.
test("SPARK-33007: simplify named_struct + from_json") {
  val noOptions = Map.empty[String, String]
  val utc = Option("UTC")
  val jsonCol: BoundReference = 'json.string.canBeNull.at(0)

  val fullSchema = StructType.fromDDL("a int, b int, c long, d string")
  val schemaAB = StructType.fromDDL("a int, b int")
  val schemaA = StructType.fromDDL("a int")
  val schemaB = StructType.fromDDL("b int")

  // `JsonToStructs` over `jsonCol` with the given schema and time zone.
  def parse(schema: StructType, zone: Option[String] = utc): JsonToStructs =
    JsonToStructs(schema, noOptions, jsonCol, zone)

  // Field number `ordinal` of `parse(schema, zone)`.
  def field(schema: StructType, ordinal: Int, zone: Option[String] = utc): GetStructField =
    GetStructField(parse(schema, zone), ordinal)

  // Both fields come from the same `JsonToStructs` and keep their names: expect the rewrite.
  val allNulls = namedStruct("a", Literal(null, IntegerType), "b", Literal(null, IntegerType))
  assertEquivalent(
    testRelation2,
    namedStruct("a", field(fullSchema, 0), "b", field(fullSchema, 1)).as("struct"),
    If(IsNull(jsonCol), allNulls, KnownNotNull(parse(schemaAB))).as("struct"))

  // Skip optimization if `namedStruct` aliases field name.
  assertEquivalent(
    testRelation2,
    namedStruct("a1", field(fullSchema, 0), "b", field(fullSchema, 1)).as("struct"),
    namedStruct("a1", field(schemaA, 0), "b", field(schemaB, 0)).as("struct"))

  // The struct repeats field name "a": the expected plan keeps two separate accesses.
  assertEquivalent(
    testRelation2,
    namedStruct("a", field(fullSchema, 0), "a", field(fullSchema, 0)).as("struct"),
    namedStruct("a", field(schemaA, 0), "a", field(schemaA, 0)).as("struct"))

  // Skip optimization if `JsonToStructs`s are not the same.
  val PST = getZoneId("-08:00")
  val pst = Option(PST.getId)
  assertEquivalent(
    testRelation2,
    namedStruct("a", field(fullSchema, 0), "b", field(fullSchema, 1, pst)).as("struct"),
    namedStruct("a", field(schemaA, 0), "b", field(schemaB, 0, pst)).as("struct"))
}
257+
258+
/**
 * Asserts that optimizing `relation.select(e1)` produces the same plan as the analyzed
 * `relation.select(e2)`, and that `e1` evaluates to the same value as `e2` on a sample
 * JSON document and on a null input.
 *
 * @param relation relation both expressions are selected from
 * @param e1 expression to run through the optimizer
 * @param e2 expression describing the expected result
 */
private def assertEquivalent(relation: LocalRelation, e1: Expression, e2: Expression): Unit = {
  val optimizedPlan = Optimizer.execute(relation.select(e1).analyze)
  val expectedPlan = relation.select(e2).analyze
  comparePlans(optimizedPlan, expectedPlan)

  // Evaluate both expressions on a well-formed document and on a null input.
  val inputs = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""", null)
  for (input <- inputs) {
    val inputRow = create_row(input)
    checkEvaluation(e1, e2.eval(inputRow), inputRow)
  }
}
202269
}

0 commit comments

Comments
 (0)