Skip to content

Commit 62f46f6

Browse files
author
Bogdan Raducanu
committed
Added outer_explode, outer_posexplode, outer_inline
1 parent 12c8c21 commit 62f46f6

File tree

8 files changed

+134
-24
lines changed

8 files changed

+134
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,9 +1621,11 @@ class Analyzer(
16211621

16221622
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
16231623
private object AliasedGenerator {
1624-
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
1625-
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
1626-
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
1624+
def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
1625+
case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
1626+
case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some(g, names, true)
1627+
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
1628+
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names, false)
16271629
case _ => None
16281630
}
16291631
}
@@ -1644,7 +1646,7 @@ class Analyzer(
16441646
var resolvedGenerator: Generate = null
16451647

16461648
val newProjectList = projectList.flatMap {
1647-
case AliasedGenerator(generator, names) if generator.childrenResolved =>
1649+
case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
16481650
// It's a sanity check, this should not happen as the previous case will throw
16491651
// exception earlier.
16501652
assert(resolvedGenerator == null, "More than one generator found in SELECT.")
@@ -1653,7 +1655,7 @@ class Analyzer(
16531655
Generate(
16541656
generator,
16551657
join = projectList.size > 1, // Only join if there are other expressions in SELECT.
1656-
outer = false,
1658+
outer = outer,
16571659
qualifier = None,
16581660
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
16591661
child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ object FunctionRegistry {
175175
expression[NullIf]("nullif"),
176176
expression[Nvl]("nvl"),
177177
expression[Nvl2]("nvl2"),
178+
expression[OuterExplode]("outer_explode"),
179+
expression[OuterInline]("outer_inline"),
180+
expression[OuterPosExplode]("outer_posexplode"),
178181
expression[PosExplode]("posexplode"),
179182
expression[Rand]("rand"),
180183
expression[Randn]("randn"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,17 @@ case class Stack(children: Seq[Expression]) extends Generator {
204204
}
205205
}
206206

207+
case class GeneratorOuter(child: Generator) extends UnaryExpression
208+
with Generator {
209+
210+
final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
211+
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
212+
213+
final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
214+
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
215+
216+
override def elementSchema: StructType = child.elementSchema
217+
}
207218
/**
208219
* A base class for [[Explode]] and [[PosExplode]].
209220
*/
@@ -233,11 +244,11 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with
233244
if (position) {
234245
new StructType()
235246
.add("pos", IntegerType, nullable = false)
236-
.add("key", kt, nullable = false)
247+
.add("key", kt, nullable = true)
237248
.add("value", vt, valueContainsNull)
238249
} else {
239250
new StructType()
240-
.add("key", kt, nullable = false)
251+
.add("key", kt, nullable = true)
241252
.add("value", vt, valueContainsNull)
242253
}
243254
}
@@ -300,7 +311,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with
300311
case class Explode(child: Expression) extends ExplodeBase {
301312
override val position: Boolean = false
302313
}
303-
314+
class OuterExplode(child: Expression) extends GeneratorOuter(Explode(child))
304315
/**
305316
* Given an input array produces a sequence of rows for each position and value in the array.
306317
*
@@ -323,7 +334,7 @@ case class Explode(child: Expression) extends ExplodeBase {
323334
case class PosExplode(child: Expression) extends ExplodeBase {
324335
override val position = true
325336
}
326-
337+
class OuterPosExplode(child: Expression) extends GeneratorOuter(PosExplode(child))
327338
/**
328339
* Explodes an array of structs into a table.
329340
*/
@@ -369,3 +380,5 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene
369380
child.genCode(ctx)
370381
}
371382
}
383+
384+
class OuterInline(child: Expression) extends GeneratorOuter(Inline(child))

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,7 @@ class Column(val expr: Expression) extends Logging {
166166

167167
// Leave an unaliased generator with an empty list of names since the analyzer will generate
168168
// the correct defaults after the nested expression's type has been resolved.
169-
case explode: Explode => MultiAlias(explode, Nil)
170-
case explode: PosExplode => MultiAlias(explode, Nil)
171-
172-
case jt: JsonTuple => MultiAlias(jt, Nil)
169+
case g: Generator => MultiAlias(g, Nil)
173170

174171
case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias))
175172

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,20 @@ case class GenerateExec(
160160

161161
// Generate looping variables.
162162
val index = ctx.freshName("index")
163+
val numElements = ctx.freshName("numElements")
164+
165+
// In case of outer=true we need to make sure the loop is executed at-least once when the
166+
// array/map contains no input.
167+
// generateOuter is an int. it is set to 1 iff outer is true and the input is empty or null.
168+
val generateOuter = ctx.freshName("generateOuter")
169+
val isOuter = if (outer) {
170+
"true"
171+
} else {
172+
"false"
173+
}
163174

164175
// Add a check if the generate outer flag is true.
165-
val checks = optionalCode(outer, data.isNull)
176+
val checks = optionalCode(outer, s"($generateOuter == 1)")
166177

167178
// Add position
168179
val position = if (e.position) {
@@ -199,21 +210,13 @@ case class GenerateExec(
199210
(initArrayData, "", values)
200211
}
201212

202-
// In case of outer=true we need to make sure the loop is executed at-least once when the
203-
// array/map contains no input. We do this by setting the looping index to -1 if there is no
204-
// input, evaluation of the array is prevented by a check in the accessor code.
205-
val numElements = ctx.freshName("numElements")
206-
val init = if (outer) {
207-
s"$numElements == 0 ? -1 : 0"
208-
} else {
209-
"0"
210-
}
211213
val numOutput = metricTerm(ctx, "numOutputRows")
212214
s"""
213215
|${data.code}
214216
|$initMapData
215217
|int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements();
216-
|for (int $index = $init; $index < $numElements; $index++) {
218+
|int $generateOuter = ($numElements == 0 && $isOuter) ? 1 : 0;
219+
|for (int $index = 0; $index < $numElements + $generateOuter; $index++) {
217220
| $numOutput.add(1);
218221
| $updateRowData
219222
| ${consume(ctx, input ++ position ++ values)}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,15 @@ object functions {
28702870
*/
28712871
def explode(e: Column): Column = withExpr { Explode(e.expr) }
28722872

2873+
/**
2874+
* Creates a new row for each element in the given array or map column.
2875+
* Unlike explode, if the array/map is null or empty then null is produced.
2876+
*
2877+
* @group collection_funcs
2878+
* @since 2.2.0
2879+
*/
2880+
def outer_explode(e: Column): Column = withExpr { new OuterExplode(e.expr) }
2881+
28732882
/**
28742883
* Creates a new row for each element with position in the given array or map column.
28752884
*
@@ -2878,6 +2887,15 @@ object functions {
28782887
*/
28792888
def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
28802889

2890+
/**
2891+
* Creates a new row for each element with position in the given array or map column.
2892+
* Unlike posexplode, if the array/map is null or empty then the row (0, null) is produced.
2893+
*
2894+
* @group collection_funcs
2895+
* @since 2.2.0
2896+
*/
2897+
def outer_posexplode(e: Column): Column = withExpr { new OuterPosExplode(e.expr) }
2898+
28812899
/**
28822900
* Extracts json object from a json string based on json path specified, and returns json string
28832901
* of the extracted json object. It will return null if the input json string is invalid.

sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,25 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
8686
df.select(explode('intList)),
8787
Row(1) :: Row(2) :: Row(3) :: Nil)
8888
}
89+
test("single outer_explode") {
90+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
91+
checkAnswer(
92+
df.select(outer_explode('intList)),
93+
Row(1) :: Row(2) :: Row(3) :: Row(0) :: Nil)
94+
}
8995

9096
test("single posexplode") {
9197
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
9298
checkAnswer(
9399
df.select(posexplode('intList)),
94100
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
95101
}
102+
test("single outer_posexplode") {
103+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
104+
checkAnswer(
105+
df.select(outer_posexplode('intList)),
106+
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(0, 0) :: Nil)
107+
}
96108

97109
test("explode and other columns") {
98110
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
@@ -109,6 +121,25 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
109121
Row(1, Seq(1, 2, 3), 2) ::
110122
Row(1, Seq(1, 2, 3), 3) :: Nil)
111123
}
124+
test("outer_explode and other columns") {
125+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
126+
127+
checkAnswer(
128+
df.select($"a", outer_explode('intList)),
129+
Row(1, 1) ::
130+
Row(1, 2) ::
131+
Row(1, 3) ::
132+
Row(2, 0) ::
133+
Nil)
134+
135+
checkAnswer(
136+
df.select($"*", outer_explode('intList)),
137+
Row(1, Seq(1, 2, 3), 1) ::
138+
Row(1, Seq(1, 2, 3), 2) ::
139+
Row(1, Seq(1, 2, 3), 3) ::
140+
Row(2, Seq(), 0) ::
141+
Nil)
142+
}
112143

113144
test("aliased explode") {
114145
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
@@ -122,13 +153,33 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
122153
Row(6) :: Nil)
123154
}
124155

156+
test("aliased outer_explode") {
157+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
158+
159+
checkAnswer(
160+
df.select(outer_explode('intList).as('int)).select('int),
161+
Row(1) :: Row(2) :: Row(3) :: Row(0) :: Nil)
162+
163+
checkAnswer(
164+
df.select(explode('intList).as('int)).select(sum('int)),
165+
Row(6) :: Nil)
166+
}
167+
125168
test("explode on map") {
126169
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
127170

128171
checkAnswer(
129172
df.select(explode('map)),
130173
Row("a", "b"))
131174
}
175+
test("outer_explode on map") {
176+
val df = Seq((1, Map("a" -> "b")), (2, Map[String, String]()),
177+
(3, Map("c" -> "d"))).toDF("a", "map")
178+
179+
checkAnswer(
180+
df.select(outer_explode('map)),
181+
Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil)
182+
}
132183

133184
test("explode on map with aliases") {
134185
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
@@ -138,6 +189,14 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
138189
Row("a", "b"))
139190
}
140191

192+
test("outer_explode on map with aliases") {
193+
val df = Seq((3, None), (1, Some(Map("a" -> "b")))).toDF("a", "map")
194+
195+
checkAnswer(
196+
df.select(outer_explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
197+
Row("a", "b") :: Row(null, null) :: Nil)
198+
}
199+
141200
test("self join explode") {
142201
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
143202
val exploded = df.select(explode('intList).as('i))
@@ -206,6 +265,18 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
206265
df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
207266
Row(1) :: Row(2) :: Nil)
208267
}
268+
test("outer_inline") {
269+
val df = Seq((1, "2"), (3, "4"), (5, "6")).toDF("col1", "col2")
270+
val df2 = df.select(when('col1 === 1, null).otherwise(array(struct('col1, 'col2))).as("col1"))
271+
checkAnswer(
272+
df2.selectExpr("inline(col1)"),
273+
Row(3, "4") :: Row(5, "6") :: Nil
274+
)
275+
checkAnswer(
276+
df2.selectExpr("outer_inline(col1)"),
277+
Row(0, null) :: Row(3, "4") :: Row(5, "6") :: Nil
278+
)
279+
}
209280

210281
test("SPARK-14986: Outer lateral view with empty generate expression") {
211282
checkAnswer(

sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
102102
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
103103
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
104104
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
105+
checkSqlGeneration("SELECT outer_explode(array())")
106+
checkSqlGeneration("SELECT outer_posexplode(array())")
107+
checkSqlGeneration("SELECT outer_inline(array(struct('a', 1)))")
105108
checkSqlGeneration("SELECT rand(1)")
106109
checkSqlGeneration("SELECT randn(3)")
107110
checkSqlGeneration("SELECT struct(1,2,3)")

0 commit comments

Comments
 (0)