|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql |
19 | 19 |
|
| 20 | +import scala.collection.mutable |
| 21 | + |
20 | 22 | import org.apache.spark.sql.catalyst.DefinedByConstructorParams |
| 23 | +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRowWithSchema} |
| 24 | +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects |
21 | 25 | import org.apache.spark.sql.functions._ |
| 26 | +import org.apache.spark.sql.internal.SQLConf |
22 | 27 | import org.apache.spark.sql.test.SharedSparkSession |
| 28 | +import org.apache.spark.sql.types.ArrayType |
23 | 29 |
|
24 | 30 | /** |
25 | 31 | * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). |
@@ -64,6 +70,33 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { |
64 | 70 | val ds100_5 = Seq(S100_5()).toDS() |
65 | 71 | ds100_5.rdd.count |
66 | 72 | } |
| 73 | + |
| 74 | + test("SPARK-29503 nest unsafe struct inside safe array") { |
| 75 | + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { |
| 76 | + val exampleDS = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items") |
| 77 | + |
| 78 | + // items: Seq[Int] => items.map { item => Seq(Struct(item)) } |
| 79 | + val result = exampleDS.select( |
| 80 | + new Column(MapObjects( |
| 81 | + (item: Expression) => array(struct(new Column(item))).expr, |
| 82 | + $"items".expr, |
| 83 | + exampleDS.schema("items").dataType.asInstanceOf[ArrayType].elementType |
| 84 | + )) as "items" |
| 85 | + ).collect() |
| 86 | + |
| 87 | + def getValueInsideDepth(result: Row, index: Int): Int = { |
| 88 | + // expected output: |
| 89 | + // WrappedArray([WrappedArray(WrappedArray([1]), WrappedArray([2]), WrappedArray([3]))]) |
| 90 | + result.getSeq[mutable.WrappedArray[_]](0)(index)(0) |
| 91 | + .asInstanceOf[GenericRowWithSchema].getInt(0) |
| 92 | + } |
| 93 | + |
| 94 | + assert(result.size === 1) |
| 95 | + assert(getValueInsideDepth(result.head, 0) === 1) |
| 96 | + assert(getValueInsideDepth(result.head, 1) === 2) |
| 97 | + assert(getValueInsideDepth(result.head, 2) === 3) |
| 98 | + } |
| 99 | + } |
67 | 100 | } |
68 | 101 |
|
69 | 102 | class S100( |
|
0 commit comments