
Commit 4bbc343

fqaiser94 and cloud-fan committed
[SPARK-31317][SQL] Add withField method to Column
### What changes were proposed in this pull request?

Added a new `withField` method to the `Column` class. This method allows users to add or replace a `StructField` in a `StructType` column (with semantics very similar to the `withColumn` method on `Dataset`).

### Why are the changes needed?

Spark users often have to work with deeply nested data, e.g. to fix a data quality issue in an existing `StructField`. To do this with the existing Spark APIs, users have to rebuild the entire struct column. For example, suppose you have the following deeply nested data structure, which has a data quality issue (`5` is missing):

```
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val data = spark.createDataFrame(
  sc.parallelize(Seq(Row(Row(
    Row(1, 2, 3),
    Row(Row(4, null, 6), Row(7, 8, 9), Row(10, 11, 12)),
    Row(13, 14, 15))))),
  StructType(Seq(
    StructField("a", StructType(Seq(
      StructField("a", StructType(Seq(
        StructField("a", IntegerType),
        StructField("b", IntegerType),
        StructField("c", IntegerType)))),
      StructField("b", StructType(Seq(
        StructField("a", StructType(Seq(
          StructField("a", IntegerType),
          StructField("b", IntegerType),
          StructField("c", IntegerType)))),
        StructField("b", StructType(Seq(
          StructField("a", IntegerType),
          StructField("b", IntegerType),
          StructField("c", IntegerType)))),
        StructField("c", StructType(Seq(
          StructField("a", IntegerType),
          StructField("b", IntegerType),
          StructField("c", IntegerType))))
      ))),
      StructField("c", StructType(Seq(
        StructField("a", IntegerType),
        StructField("b", IntegerType),
        StructField("c", IntegerType))))
    )))
  ))
).cache

data.show(false)
+---------------------------------------------------------------+
|a                                                              |
+---------------------------------------------------------------+
|[[1, 2, 3], [[4,, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]  |
+---------------------------------------------------------------+
```

Currently, to replace the missing value, users have to do something like this:

```
val result = data.withColumn("a", struct(
  $"a.a",
  struct(
    struct(
      $"a.b.a.a",
      lit(5).as("b"),
      $"a.b.a.c"
    ).as("a"),
    $"a.b.b",
    $"a.b.c"
  ).as("b"),
  $"a.c"
))

result.show(false)
+---------------------------------------------------------------+
|a                                                              |
+---------------------------------------------------------------+
|[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+---------------------------------------------------------------+
```

As you can see above, with the existing methods users must call the `struct` function and list all fields, including the fields they don't want to change. This is not ideal because:

> this leads to complex, fragile code that cannot survive schema evolution. [SPARK-16483](https://issues.apache.org/jira/browse/SPARK-16483)

In contrast, with the method added in this PR, a user can simply do this:

```
val result = data.withColumn("a", 'a.withField("b.a.b", lit(5)))
result.show(false)
+---------------------------------------------------------------+
|a                                                              |
+---------------------------------------------------------------+
|[[1, 2, 3], [[4, 5, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+---------------------------------------------------------------+
```

This is the first of possibly several methods that could be added to the `Column` class to make it easier to manipulate nested data. Other methods under discussion in [SPARK-22231](https://issues.apache.org/jira/browse/SPARK-22231) include `drop` and `renameField`; those should be added in separate PRs.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

New unit tests were added; they must pass in Jenkins.
### Related JIRAs:

- https://issues.apache.org/jira/browse/SPARK-22231
- https://issues.apache.org/jira/browse/SPARK-16483

Closes #27066 from fqaiser94/SPARK-22231-withField.

Lead-authored-by: [email protected] <[email protected]>
Co-authored-by: fqaiser94 <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 8d5c094 commit 4bbc343

File tree

8 files changed: +815 −3 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 58 additions & 0 deletions
```diff
@@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 
   override def prettyName: String = "str_to_map"
 }
+
+/**
+ * Adds/replaces field in struct by name.
+ */
+case class WithFields(
+    structExpr: Expression,
+    names: Seq[String],
+    valExprs: Seq[Expression]) extends Unevaluable {
+
+  assert(names.length == valExprs.length)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!structExpr.dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure(
+        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def children: Seq[Expression] = structExpr +: valExprs
+
+  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+
+  override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
+
+  override def nullable: Boolean = structExpr.nullable
+
+  override def prettyName: String = "with_fields"
+
+  lazy val evalExpr: Expression = {
+    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+    }
+
+    val addOrReplaceExprs = names.zip(valExprs)
+
+    val resolver = SQLConf.get.resolver
+    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
+      case (resultExprs, newExpr @ (newExprName, _)) =>
+        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
+          resultExprs.map {
+            case (name, _) if resolver(name, newExprName) => newExpr
+            case x => x
+          }
+        } else {
+          resultExprs :+ newExpr
+        }
+    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+
+    val expr = CreateNamedStruct(newExprs)
+    if (structExpr.nullable) {
+      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+    } else {
+      expr
+    }
+  }
+}
```
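The heart of this expression is the add-or-replace fold in `evalExpr`: each `(name, value)` pair either overwrites every existing field whose name matches (per the configured resolver) or is appended as a new field. Below is a minimal, self-contained Scala sketch of that fold, assuming plain `(String, Int)` pairs in place of Catalyst expressions and exact string equality in place of `SQLConf.get.resolver`; it illustrates the semantics only and is not Spark code.

```scala
object AddOrReplaceSketch extends App {
  // Sketch of the WithFields.evalExpr fold: updates are applied in order,
  // replacing all same-named fields or appending at the end of the struct.
  def addOrReplace(
      existing: Seq[(String, Int)],
      updates: Seq[(String, Int)]): Seq[(String, Int)] =
    updates.foldLeft(existing) { case (fields, update @ (name, _)) =>
      if (fields.exists(_._1 == name)) {
        // Replace every field whose name matches (duplicate names are all replaced).
        fields.map {
          case (n, _) if n == name => update
          case other => other
        }
      } else {
        // Append a brand-new field at the end of the struct.
        fields :+ update
      }
    }

  // struct<a:1, b:2>: replace "b", then add "c".
  println(addOrReplace(Seq("a" -> 1, "b" -> 2), Seq("b" -> 3, "c" -> 4)))
  // prints: List((a,1), (b,3), (c,4))
}
```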

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala

Lines changed: 12 additions & 1 deletion
```diff
@@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-
+      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+        val name = w.dataType(ordinal).name
+        val matches = names.zip(valExprs).filter(_._1 == name)
+        if (matches.nonEmpty) {
+          // return last matching element as that is the final value for the field being extracted.
+          // For example, if a user submits a query like this:
+          // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
+          // we want to return `lit(2)` (and not `lit(1)`).
+          matches.last._2
+        } else {
+          GetStructField(struct, ordinal, maybeName)
+        }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
         // Instead of selecting the field on the entire array, select it from each member
```
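With this rule, a `getField` over the result of `withField` never materializes the intermediate struct, and when the same field is written twice, the last write wins. A hedged spark-shell sketch of that behaviour (the column name, data, and printed output are illustrative):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) struct_col")

// "b" is written twice; the optimizer folds the chain so that extracting
// "b" simplifies straight to the last value written, i.e. lit(2).
df.select(
  $"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b") as "b"
).show()
// +---+
// |  b|
// +---+
// |  2|
// +---+
```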

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateSerialization,
       RemoveRedundantAliases,
       RemoveNoopOperators,
+      CombineWithFields,
       SimplifyExtractValueOps,
       CombineConcats) ++
       extendedOperatorOptimizationRules
@@ -207,7 +208,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
-    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
+    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
+    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
 
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
@@ -240,7 +242,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       PullupCorrelatedPredicates.ruleName ::
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
-      NormalizeFloatingNumbers.ruleName :: Nil
+      NormalizeFloatingNumbers.ruleName ::
+      ReplaceWithFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
```
Lines changed: 42 additions & 0 deletions
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.WithFields
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule


/**
 * Combines all adjacent [[WithFields]] expressions into a single [[WithFields]] expression.
 */
object CombineWithFields extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
  }
}

/**
 * Replaces [[WithFields]] expression with an evaluable expression.
 */
object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case w: WithFields => w.evalExpr
  }
}
```
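Together, these two rules mean a chain of `withField` calls collapses to a single struct rebuild before evaluation. The following standalone Scala sketch models the `CombineWithFields` collapse on a toy expression tree; all type names here are hypothetical stand-ins for the Catalyst classes, not Spark code.

```scala
object CombineSketch extends App {
  // Toy model of CombineWithFields: nested WithFields nodes are flattened
  // into one node whose name/value lists are concatenated in order.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class WithFieldsNode(struct: Expr, names: Seq[String], values: Seq[Int]) extends Expr

  def combine(e: Expr): Expr = e match {
    case WithFieldsNode(WithFieldsNode(struct, names1, values1), names2, values2) =>
      // Same shape as the rule above: merge the inner node into the outer
      // one, then keep collapsing in case more levels remain.
      combine(WithFieldsNode(struct, names1 ++ names2, values1 ++ values2))
    case other => other
  }

  // 'a.withField("b1", 4).withField("c1", 5) builds two nested nodes ...
  val nested = WithFieldsNode(WithFieldsNode(Attr("a"), Seq("b1"), Seq(4)), Seq("c1"), Seq(5))
  // ... which collapse into one, mirroring the first test in the suite that follows.
  assert(combine(nested) == WithFieldsNode(Attr("a"), Seq("b1", "c1"), Seq(4, 5)))
}
```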
Lines changed: 76 additions & 0 deletions
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class CombineWithFieldsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
  }

  private val testRelation = LocalRelation('a.struct('a1.int))

  test("combines two WithFields") {
    val originalQuery = testRelation
      .select(Alias(
        WithFields(
          WithFields(
            'a,
            Seq("b1"),
            Seq(Literal(4))),
          Seq("c1"),
          Seq(Literal(5))), "out")())

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("combines three WithFields") {
    val originalQuery = testRelation
      .select(Alias(
        WithFields(
          WithFields(
            WithFields(
              'a,
              Seq("b1"),
              Seq(Literal(4))),
            Seq("c1"),
            Seq(Literal(5))),
          Seq("d1"),
          Seq(Literal(6))), "out")())

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
      .analyze

    comparePlans(optimized, correctAnswer)
  }
}
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 57 additions & 0 deletions
```diff
@@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
   }
+
+  private val structAttr = 'struct1.struct('a.int)
+  private val testStructRelation = LocalRelation(structAttr)
+
+  test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
+    val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
+    val query = testStructRelation.select(
+      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(1) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test(
+    "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
+    val query = testStructRelation
+      .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
+        Some("b")) as "outerAtt")
+    val expected = testStructRelation.select(Literal(2) as "outerAtt")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on the same WithFields") {
+    val query = testStructRelation
+      .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct1A",
+        GetStructField('struct2, 1, Some("b")) as "struct1B")
+    val expected = testStructRelation.select(
+      GetStructField('struct1, 0, Some("a")) as "struct1A",
+      Literal(2) as "struct1B")
+    checkRule(query, expected)
+  }
+
+  test("collapse multiple GetStructField on different WithFields") {
+    val query = testStructRelation
+      .select(
+        WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
+        WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+      .select(
+        GetStructField('struct2, 0, Some("a")) as "struct2A",
+        GetStructField('struct2, 1, Some("b")) as "struct2B",
+        GetStructField('struct3, 0, Some("a")) as "struct3A",
+        GetStructField('struct3, 1, Some("b")) as "struct3B")
+    val expected = testStructRelation
+      .select(
+        GetStructField('struct1, 0, Some("a")) as "struct2A",
+        Literal(2) as "struct2B",
+        GetStructField('struct1, 0, Some("a")) as "struct3A",
+        Literal(3) as "struct3B")
+    checkRule(query, expected)
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 66 additions & 0 deletions
```diff
@@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging {
    */
   def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) }
 
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that adds/replaces field in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: {"a":1,"b":2,"c":3}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".withField("b", lit(3)))
+   *   // result: {"a":1,"b":3}
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".withField("c", lit(3)))
+   *   // result: null of type struct<a:int,b:int,c:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".withField("b", lit(100)))
+   *   // result: {"a":1,"b":100,"b":100}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: {"a":{"a":1,"b":2,"c":3}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def withField(fieldName: String, col: Column): Column = withExpr {
+    require(fieldName != null, "fieldName cannot be null")
+    require(col != null, "col cannot be null")
+
+    val nameParts = if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+    withFieldHelper(expr, nameParts, Nil, col.expr)
+  }
+
+  private def withFieldHelper(
+      struct: Expression,
+      namePartsRemaining: Seq[String],
+      namePartsDone: Seq[String],
+      value: Expression) : WithFields = {
+    val name = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      WithFields(struct, name :: Nil, value :: Nil)
+    } else {
+      val newNamesRemaining = namePartsRemaining.tail
+      val newNamesDone = namePartsDone :+ name
+      val newValue = withFieldHelper(
+        struct = UnresolvedExtractValue(struct, Literal(name)),
+        namePartsRemaining = newNamesRemaining,
+        namePartsDone = newNamesDone,
+        value = value)
+      WithFields(struct, name :: Nil, newValue :: Nil)
+    }
+  }
+
   /**
    * An expression that gets a field by name in a `StructType`.
    *
```
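For a dotted path such as `"b.a.b"`, `withFieldHelper` recurses once per name part, extracting the struct at each level and wrapping the remainder in its own single-field `WithFields`. Below is a minimal standalone sketch of that recursion, using a hypothetical simplified expression ADT rather than the real Catalyst classes, and omitting the `namePartsDone` bookkeeping.

```scala
object WithFieldPathSketch extends App {
  // Hypothetical stand-ins for Catalyst's expression classes.
  sealed trait E
  case class Attr(name: String) extends E
  case class ExtractField(struct: E, name: String) extends E
  case class WithFieldsN(struct: E, names: Seq[String], values: Seq[E]) extends E
  case class Lit(value: Int) extends E

  // Mirrors withFieldHelper: the last name part carries the new value;
  // every earlier part extracts a child struct and re-wraps it.
  def withFieldHelper(struct: E, parts: Seq[String], value: E): WithFieldsN =
    parts match {
      case Seq(last) => WithFieldsN(struct, Seq(last), Seq(value))
      case head +: tail =>
        val rewrittenChild = withFieldHelper(ExtractField(struct, head), tail, value)
        WithFieldsN(struct, Seq(head), Seq(rewrittenChild))
    }

  // $"a".withField("b.c", lit(5)) expands (schematically) to:
  println(withFieldHelper(Attr("a"), Seq("b", "c"), Lit(5)))
  // WithFieldsN(Attr(a),List(b),List(WithFieldsN(ExtractField(Attr(a),b),List(c),List(Lit(5)))))
}
```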
