Skip to content

Commit 52302a8

Browse files
yjshenmarmbrus
authored andcommitted
[SPARK-8407] [SQL] complex type constructors: struct and named_struct
This is a follow-up to [SPARK-8283](https://issues.apache.org/jira/browse/SPARK-8283) ([PR-6828](#6828)) to support both `struct` and `named_struct` in Spark SQL. After [#6725](#6828), the semantics of [`CreateStruct`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala#L56) changed slightly: it is no longer limited to columns that are `NamedExpression`s, and it names non-`NamedExpression` fields following the Hive convention (col1, col2, ...). This PR both loosens [`struct`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L723) to take children of type `Expression` and adds `named_struct` support. Author: Yijie Shen <[email protected]> Closes #6874 from yijieshen/SPARK-8283 and squashes the following commits: 4cd3375 [Yijie Shen] change struct documentation d599d0b [Yijie Shen] rebase code 9a7039e [Yijie Shen] fix reviews and regenerate golden answers b487354 [Yijie Shen] replace assert using checkAnswer f07e114 [Yijie Shen] tiny fix 9613be9 [Yijie Shen] review fix 7fef712 [Yijie Shen] Fix checkInputTypes' implementation using foldable and nullable 60812a7 [Yijie Shen] Fix type check 828d694 [Yijie Shen] remove unnecessary resolved assertion inside dataType method fd3cd8e [Yijie Shen] remove type check from eval 7a71255 [Yijie Shen] tiny fix ccbbd86 [Yijie Shen] Fix reviews 47da332 [Yijie Shen] remove nameStruct API from DataFrame 917e680 [Yijie Shen] Fix reviews 4bd75ad [Yijie Shen] loosen struct method in functions.scala to take Expression children 0acb7be [Yijie Shen] Add CreateNamedStruct in both DataFrame function API and FunctionRegistery
1 parent afa021e commit 52302a8

File tree

9 files changed

+126
-13
lines changed

9 files changed

+126
-13
lines changed

python/pyspark/sql/functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,6 @@ def struct(*cols):
467467
"""Creates a new struct column.
468468
469469
:param cols: list of column names (string) or list of :class:`Column` expressions
470-
that are named or aliased.
471470
472471
>>> df.select(struct('age', 'name').alias("struct")).collect()
473472
[Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ object FunctionRegistry {
9696
expression[Rand]("rand"),
9797
expression[Randn]("randn"),
9898
expression[CreateStruct]("struct"),
99+
expression[CreateNamedStruct]("named_struct"),
99100
expression[Sqrt]("sqrt"),
100101

101102
// math functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst
2021
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2122
import org.apache.spark.sql.catalyst.util.TypeUtils
23+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2224
import org.apache.spark.sql.types._
25+
import org.apache.spark.unsafe.types.UTF8String
2326

2427
/**
2528
* Returns an Array containing the evaluation of all children expressions.
@@ -54,6 +57,8 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
5457

5558
override def foldable: Boolean = children.forall(_.foldable)
5659

60+
override lazy val resolved: Boolean = childrenResolved
61+
5762
override lazy val dataType: StructType = {
5863
val fields = children.zipWithIndex.map { case (child, idx) =>
5964
child match {
@@ -74,3 +79,47 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
7479

7580
override def prettyName: String = "struct"
7681
}
82+
83+
/**
 * Creates a struct with the given field names and values.
 *
 * The children alternate between a field name and the expression producing that field's
 * value. Every name expression must be a foldable, non-nullable `StringType` expression,
 * because the names are evaluated eagerly (against `EmptyRow`) to build the struct's schema.
 *
 * @param children Seq(name1, val1, name2, val2, ...)
 */
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

  // Split the flat (name, value, name, value, ...) child list into two parallel lists.
  // Only safe to force once `checkInputDataTypes` has confirmed an even number of children;
  // an odd trailing child would not match the two-element pattern.
  private lazy val (nameExprs, valExprs) =
    children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip

  // Field names are evaluated without an input row, hence the foldable/non-null requirement
  // enforced in `checkInputDataTypes`.
  private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)

  /** Struct schema: one field per (name, value) pair, carrying the value's type and nullability. */
  override lazy val dataType: StructType = {
    val fields = names.zip(valExprs).map { case (name, valExpr) =>
      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
    }
    StructType(fields)
  }

  /** Foldable when every value expression is foldable. */
  override def foldable: Boolean = valExprs.forall(_.foldable)

  /** The struct itself is never null; individual fields may be. */
  override def nullable: Boolean = false

  /**
   * Fails when the number of children is odd, or when any name (odd-position) expression
   * is not a foldable, non-nullable `StringType` expression.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size % 2 != 0) {
      TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.")
    } else {
      // The nullability check must be made on the name expression itself (`e.nullable`).
      // This expression's own `nullable` is always false, so testing it here would let
      // nullable names through and fail later when the name is evaluated.
      val invalidNames =
        nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !e.nullable)
      if (invalidNames.nonEmpty) {
        TypeCheckResult.TypeCheckFailure(
          s"Odd position only allow foldable and not-null StringType expressions, got :" +
          s" ${invalidNames.mkString(",")}")
      } else {
        TypeCheckResult.TypeCheckSuccess
      }
    }
  }

  /** Evaluates each value expression against `input` and packs the results into a row. */
  override def eval(input: InternalRow): Any = {
    InternalRow(valExprs.map(_.eval(input)): _*)
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
160160
assertError(Explode('intField),
161161
"input to function explode should be array or map type")
162162
}
163+
164+
test("check types for CreateNamedStruct") {
165+
assertError(
166+
CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
167+
assertError(
168+
CreateNamedStruct(Seq(1, "a", "b", 2.0)),
169+
"Odd position only allow foldable and not-null StringType expressions")
170+
assertError(
171+
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
172+
"Odd position only allow foldable and not-null StringType expressions")
173+
}
163174
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.scalatest.exceptions.TestFailedException
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
2224
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -119,11 +121,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
119121

120122
test("CreateStruct") {
121123
val row = create_row(1, 2, 3)
122-
val c1 = 'a.int.at(0).as("a")
123-
val c3 = 'c.int.at(2).as("c")
124+
val c1 = 'a.int.at(0)
125+
val c3 = 'c.int.at(2)
124126
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
125127
}
126128

129+
test("CreateNamedStruct") {
130+
val row = InternalRow(1, 2, 3)
131+
val c1 = 'a.int.at(0)
132+
val c3 = 'c.int.at(2)
133+
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
134+
}
135+
136+
test("CreateNamedStruct with literal field") {
137+
val row = InternalRow(1, 2, 3)
138+
val c1 = 'a.int.at(0)
139+
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row)
140+
}
141+
142+
test("CreateNamedStruct from all literal fields") {
143+
checkEvaluation(
144+
CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty)
145+
}
146+
127147
test("test dsl for complex type") {
128148
def quickResolve(u: UnresolvedExtractValue): Expression = {
129149
ExtractValue(u.child, u.extraction, _ == _)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -739,17 +739,18 @@ object functions {
739739
def sqrt(colName: String): Column = sqrt(Column(colName))
740740

741741
/**
742-
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
743-
* a derived column expression that is named (i.e. aliased).
742+
* Creates a new struct column.
743+
* If the input column is a column in a [[DataFrame]], or a derived column expression
744+
* that is named (i.e. aliased), its name is kept as the StructField's name;
745+
* otherwise, the newly generated StructField's name is auto-generated as col${index + 1},
746+
* i.e. col1, col2, col3, ...
744747
*
745748
* @group normal_funcs
746749
* @since 1.4.0
747750
*/
748751
@scala.annotation.varargs
749752
def struct(cols: Column*): Column = {
750-
require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
751-
s"struct input columns must all be named or aliased ($cols)")
752-
CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
753+
CreateStruct(cols.map(_.expr))
753754
}
754755

755756
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,42 @@ class DataFrameFunctionsSuite extends QueryTest {
7979
assert(row.getAs[Row](0) === Row(2, "str"))
8080
}
8181

82-
test("struct: must use named column expression") {
83-
intercept[IllegalArgumentException] {
84-
struct(col("a") * 2)
85-
}
82+
test("struct with column expression to be automatically named") {
83+
val df = Seq((1, "str")).toDF("a", "b")
84+
val result = df.select(struct((col("a") * 2), col("b")))
85+
86+
val expectedType = StructType(Seq(
87+
StructField("col1", IntegerType, nullable = false),
88+
StructField("b", StringType)
89+
))
90+
assert(result.first.schema(0).dataType === expectedType)
91+
checkAnswer(result, Row(Row(2, "str")))
92+
}
93+
94+
test("struct with literal columns") {
95+
val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
96+
val result = df.select(struct((col("a") * 2), lit(5.0)))
97+
98+
val expectedType = StructType(Seq(
99+
StructField("col1", IntegerType, nullable = false),
100+
StructField("col2", DoubleType, nullable = false)
101+
))
102+
103+
assert(result.first.schema(0).dataType === expectedType)
104+
checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0))))
105+
}
106+
107+
test("struct with all literal columns") {
108+
val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
109+
val result = df.select(struct(lit("v"), lit(5.0)))
110+
111+
val expectedType = StructType(Seq(
112+
StructField("col1", StringType, nullable = false),
113+
StructField("col2", DoubleType, nullable = false)
114+
))
115+
116+
assert(result.first.schema(0).dataType === expectedType)
117+
checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0))))
86118
}
87119

88120
test("constant functions") {

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
132132
lower("AA"), "10",
133133
repeat(lower("AA"), 3), "11",
134134
lower(repeat("AA", 3)), "12",
135-
printf("Bb%d", 12), "13",
135+
printf("bb%d", 12), "13",
136136
repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""")
137137

138138
createQueryTest("NaN to Decimal",

0 commit comments

Comments
 (0)