Skip to content

Commit 4bd75ad

Browse files
committed
loosen struct method in functions.scala to take Expression children
1 parent 0acb7be commit 4bd75ad

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,9 +747,7 @@ object functions {
747747
*/
748748
@scala.annotation.varargs
749749
def struct(cols: Column*): Column = {
750-
require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
751-
s"struct input columns must all be named or aliased ($cols)")
752-
CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
750+
CreateStruct(cols.map(_.expr))
753751
}
754752

755753
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,49 @@ class DataFrameFunctionsSuite extends QueryTest {
7979
assert(row.getAs[Row](0) === Row(2, "str"))
8080
}
8181

82-
test("struct: must use named column expression") {
83-
intercept[IllegalArgumentException] {
84-
struct(col("a") * 2)
85-
}
82+
test("struct with column expression to be automatically named") {
83+
val df = Seq((1, "str")).toDF("a", "b")
84+
val row = df.select(struct((col("a") * 2), col("b"))).first()
85+
86+
val expectedType = StructType(Seq(
87+
StructField("col1", IntegerType, nullable = false),
88+
StructField("b", StringType)
89+
))
90+
assert(row.schema(0).dataType === expectedType)
91+
assert(row.getAs[Row](0) === Row(2, "str"))
92+
}
93+
94+
test("struct with literal columns") {
95+
val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
96+
val row = df.select(
97+
struct((col("a") * 2), lit(5.0))).take(2)
98+
99+
val expectedType = StructType(Seq(
100+
StructField("col1", IntegerType, nullable = false),
101+
StructField("col2", DoubleType, nullable = false)
102+
))
103+
104+
assert(row(0).schema(0).dataType === expectedType)
105+
assert(row(0).getAs[Row](0) === Row(2, 5.0))
106+
assert(row(1).getAs[Row](0) === Row(4, 5.0))
107+
}
108+
109+
test("struct with all literal columns") {
110+
val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
111+
val row = df.select(
112+
struct(lit("v"), lit(5.0))).take(2)
113+
114+
val expectedType = StructType(Seq(
115+
StructField("col1", StringType, nullable = false),
116+
StructField("col2", DoubleType, nullable = false)
117+
))
118+
119+
assert(row(0).schema(0).dataType === expectedType)
120+
assert(row(0).getAs[Row](0) === Row("v", 5.0))
121+
assert(row(1).getAs[Row](0) === Row("v", 5.0))
86122
}
87123

124+
88125
test("named_struct with column expression") {
89126
val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
90127
val row = df.select(

0 commit comments

Comments
 (0)