diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 42abc0eafda7a..fabb5634ad10c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -371,8 +371,16 @@ object TableOutputResolver {
       resolveColumnsByPosition(tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath)
     }
     if (res.length == 1) {
-      val func = LambdaFunction(res.head, Seq(param))
-      Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)())
+      if (res.head == param) {
+        // If the element type is the same, we can reuse the input array directly.
+        Some(
+          Alias(nullCheckedInput, expected.name)(
+            nonInheritableMetadataKeys =
+              Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
+      } else {
+        val func = LambdaFunction(res.head, Seq(param))
+        Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)())
+      }
     } else {
       None
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
index d91a080d8fe89..21a049e914182 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.QueryContext
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal}
 import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -304,6 +304,36 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
 
   def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan
 
+  test("SPARK-49352: Avoid redundant array transform for identical expression") {
+    def assertArrayField(fromType: ArrayType, toType: ArrayType, hasTransform: Boolean): Unit = {
+      val table = TestRelation(Seq($"a".int, $"arr".array(toType)))
+      val query = TestRelation(Seq($"arr".array(fromType), $"a".int))
+
+      val writePlan = byName(table, query).analyze
+
+      assertResolved(writePlan)
+      checkAnalysis(writePlan, writePlan)
+
+      val transform = writePlan.children.head.expressions.exists { e =>
+        e.find {
+          case _: ArrayTransform => true
+          case _ => false
+        }.isDefined
+      }
+      if (hasTransform) {
+        assert(transform)
+      } else {
+        assert(!transform)
+      }
+    }
+
+    assertArrayField(ArrayType(LongType), ArrayType(LongType), hasTransform = false)
+    assertArrayField(
+      ArrayType(new StructType().add("x", "int").add("y", "int")),
+      ArrayType(new StructType().add("y", "int").add("x", "byte")),
+      hasTransform = true)
+  }
+
   test("SPARK-33136: output resolved on complex types for V2 write commands") {
     def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = {
       val table = TestRelation(StructType(Seq(StructField("a", toType))))