Skip to content

Commit c46ef5c

Browse files
committed
address review comments
1 parent 83fef40 commit c46ef5c

File tree

2 files changed

+16
-1
lines changed
  • sql/catalyst/src
    • main/scala/org/apache/spark/sql/catalyst/expressions
    • test/scala/org/apache/spark/sql/catalyst/expressions

2 files changed

+16
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
10421042
}
10431043
val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
10441044
ctx.splitExpressions(fieldsEvalCode, "castStruct",
1045-
("InternalRow", ctx.INPUT_ROW) :: (rowClass, result) :: ("InternalRow", tmpRow) :: Nil)
1045+
("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
10461046
} else {
10471047
fieldsEvalCode.mkString("\n")
10481048
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
830830

831831
test("SPARK-22500: cast for struct should not generate codes beyond 64KB") {
832832
val N = 1000
833+
val M = 250
833834

834835
val from1 = new StructType(
835836
(1 to N).map(i => StructField(s"s$i", StringType)).toArray)
@@ -856,5 +857,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
856857
val input3 = Row.fromSeq((1 to N).map(i => Row(i)))
857858
val output3 = Row.fromSeq((1 to N).map(i => Row(i.toLong)))
858859
checkEvaluation(cast(Literal.create(input3, from3), to3), output3)
860+
861+
val fromInner = new StructType(
862+
(1 to M).map(i => StructField(s"s$i", DoubleType)).toArray)
863+
val toInner = new StructType(
864+
(1 to M).map(i => StructField(s"i$i", IntegerType)).toArray)
865+
val inputInner = Row.fromSeq((1 to M).map(i => i + 0.5))
866+
val outputInner = Row.fromSeq((1 to M))
867+
val fromOuter = new StructType(
868+
(1 to M).map(i => StructField(s"s$i", fromInner)).toArray)
869+
val toOuter = new StructType(
870+
(1 to M).map(i => StructField(s"s$i", toInner)).toArray)
871+
val inputOuter = Row.fromSeq((1 to M).map(_ => inputInner))
872+
val outputOuter = Row.fromSeq((1 to M).map(_ => outputInner))
873+
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
859874
}
860875
}

0 commit comments

Comments
 (0)