Commit 0acb7be

Add CreateNamedStruct in both DataFrame function API and FunctionRegistry

1 parent 1b0c8e6
5 files changed: +169 −2 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions

@@ -96,6 +96,7 @@ object FunctionRegistry {
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[CreateStruct]("struct"),
+    expression[CreateNamedStruct]("named_struct"),
     expression[Sqrt]("sqrt"),
 
     // math functions
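
With this registration the function also becomes callable from SQL, not just the DataFrame API. A minimal usage sketch (hypothetical snippet; assumes a SQLContext named `ctx`, like the one used in the test suites below):

  // Hypothetical SQL-side usage; assumes a SQLContext `ctx` is in scope.
  val structured = ctx.sql("SELECT named_struct('a', 1, 'b', 'x') AS s")
  // `s` is a struct column with an int field `a` and a string field `b`.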

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 43 additions & 0 deletions

@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Returns an Array containing the evaluation of all children expressions.
@@ -74,3 +76,44 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 
   override def prettyName: String = "struct"
 }
+
+/**
+ * Creates a struct with the given field names and values.
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
+ */
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+  assert(children.size % 2 == 0, "NamedStruct expects an even number of arguments.")
+
+  private val nameExprs = children.zipWithIndex.filter(_._2 % 2 == 0).map(_._1)
+  private val valExprs = children.zipWithIndex.filter(_._2 % 2 == 1).map(_._1)
+
+  private lazy val names = nameExprs.map { case name =>
+    name match {
+      case NonNullLiteral(str, StringType) =>
+        str.asInstanceOf[UTF8String].toString
+      case _ =>
+        throw new IllegalArgumentException("Field name expressions (children at even positions)" +
+          s" should be Literal(_, StringType), got ${name.dataType} instead")
+    }
+  }
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override lazy val resolved: Boolean = childrenResolved
+
+  override lazy val dataType: StructType = {
+    assert(resolved,
+      s"CreateNamedStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
+    val fields = names.zip(valExprs).map { case (name, valExpr) =>
+      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+    }
+    StructType(fields)
+  }
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    InternalRow(valExprs.map(_.eval(input)): _*)
+  }
+}
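
The expression reads its children as an alternating (name, value) sequence: entries at even (0-based) positions must be non-null string literals and become field names, while entries at odd positions are evaluated to produce the field values. The following standalone sketch (plain Scala values standing in for Catalyst expressions, for illustration only) shows the same splitting logic:

  // Standalone illustration of how CreateNamedStruct splits its children.
  // Plain values stand in for Catalyst Expressions here.
  val children: Seq[Any] = Seq("a", 1, "b", "x")
  val names  = children.zipWithIndex.filter(_._2 % 2 == 0).map(_._1)  // Seq("a", "b")
  val values = children.zipWithIndex.filter(_._2 % 2 == 1).map(_._1)  // Seq(1, "x")
  assert(names == Seq("a", "b") && values == Seq(1, "x"))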

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 52 additions & 2 deletions

@@ -100,6 +100,38 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(getStructField(nullStruct, "a").nullable === true)
   }
 
+  test("complex type") {
+    val row = create_row(
+      "^Ba*n",                        // 0
+      null.asInstanceOf[UTF8String],  // 1
+      create_row("aa", "bb"),         // 2
+      Map("aa" -> "bb"),              // 3
+      Seq("aa", "bb")                 // 4
+    )
+
+    val typeS = StructType(
+      StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
+    )
+    val typeMap = MapType(StringType, StringType)
+    val typeArray = ArrayType(StringType)
+
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
+      Literal("aa")), "bb", row)
+    checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
+    checkEvaluation(
+      GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
+      Literal.create(null, StringType)), null, row)
+
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
+      Literal(1)), "bb", row)
+    checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
+    checkEvaluation(
+      GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
+      Literal.create(null, IntegerType)), null, row)
+  }
+
   test("GetArrayStructFields") {
     val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
@@ -119,11 +151,29 @@
 
   test("CreateStruct") {
     val row = create_row(1, 2, 3)
-    val c1 = 'a.int.at(0).as("a")
-    val c3 = 'c.int.at(2).as("c")
+    val c1 = 'a.int.at(0)
+    val c3 = 'c.int.at(2)
     checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
   }
 
+  test("CreateNamedStruct") {
+    val row = InternalRow(1, 2, 3)
+    val c1 = 'a.int.at(0)
+    val c3 = 'c.int.at(2)
+    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
+  }
+
+  test("CreateNamedStruct with literal field") {
+    val row = InternalRow(1, 2, 3)
+    val c1 = 'a.int.at(0)
+    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row)
+  }
+
+  test("CreateNamedStruct from all literal fields") {
+    checkEvaluation(
+      CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty)
+  }
+
   test("test dsl for complex type") {
     def quickResolve(u: UnresolvedExtractValue): Expression = {
       ExtractValue(u.child, u.extraction, _ == _)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 15 additions & 0 deletions

@@ -762,6 +762,21 @@ object functions {
     struct((colName +: colNames).map(col) : _*)
   }
 
+  /**
+   * Creates a new struct column with the given field names and columns.
+   * The input must have length 2 * n and be arranged as (name1, col1, name2, col2, ...),
+   * where each name is a string literal column.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def named_struct(cols: Column*): Column = {
+    require(cols.length % 2 == 0,
+      s"named_struct expects an even number of arguments.")
+    CreateNamedStruct(cols.map(_.expr))
+  }
+
   /**
    * Converts a string expression to upper case.
    *
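
For reference, a minimal DataFrame-side sketch mirroring the tests below (hypothetical snippet; assumes a SQLContext with its implicits and org.apache.spark.sql.functions._ in scope):

  // Hypothetical usage of the new DataFrame function; mirrors the suite below.
  val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
  val withStruct = df.select(named_struct(lit("x"), col("a") * 2, lit("y"), col("b")).as("s"))
  // `s` is a struct column with fields x: Int (from a * 2) and y: String (from b).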

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 58 additions & 0 deletions

@@ -85,6 +85,64 @@ class DataFrameFunctionsSuite extends QueryTest {
     }
   }
 
+  test("named_struct with column expression") {
+    val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
+    val row = df.select(
+      named_struct(lit("x"), (col("a") * 2), lit("y"), col("b"))).take(2)
+
+    val expectedType = StructType(Seq(
+      StructField("x", IntegerType, nullable = false),
+      StructField("y", StringType)
+    ))
+
+    assert(row(0).schema(0).dataType === expectedType)
+    assert(row(0).getAs[Row](0) === Row(2, "str1"))
+    assert(row(1).getAs[Row](0) === Row(4, "str2"))
+  }
+
+  test("named_struct with literal columns") {
+    val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
+    val row = df.select(
+      named_struct(lit("x"), (col("a") * 2), lit("y"), lit(5.0))).take(2)
+
+    val expectedType = StructType(Seq(
+      StructField("x", IntegerType, nullable = false),
+      StructField("y", DoubleType, nullable = false)
+    ))
+
+    assert(row(0).schema(0).dataType === expectedType)
+    assert(row(0).getAs[Row](0) === Row(2, 5.0))
+    assert(row(1).getAs[Row](0) === Row(4, 5.0))
+  }
+
+  test("named_struct with all literal columns") {
+    val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b")
+    val row = df.select(
+      named_struct(lit("x"), lit("v"), lit("y"), lit(5.0))).take(2)
+
+    val expectedType = StructType(Seq(
+      StructField("x", StringType, nullable = false),
+      StructField("y", DoubleType, nullable = false)
+    ))
+
+    assert(row(0).schema(0).dataType === expectedType)
+    assert(row(0).getAs[Row](0) === Row("v", 5.0))
+    assert(row(1).getAs[Row](0) === Row("v", 5.0))
+  }
+
+  test("named_struct with odd arguments") {
+    intercept[IllegalArgumentException] {
+      named_struct(col("x"))
+    }
+  }
+
+  test("named_struct with non string literal names") {
+    val df = Seq((1, "str")).toDF("a", "b")
+    intercept[IllegalArgumentException] {
+      df.select(named_struct(lit(1), (col("a") * 2), lit("y"), lit(5.0)))
+    }
+  }
+
   test("constant functions") {
     checkAnswer(
       ctx.sql("SELECT E()"),
