Skip to content

Commit a50d42e

Browse files
mn-mikkemn-mikke
authored andcommitted
[SPARK-23821][SQL] Code-styling improvements
1 parent eeab727 commit a50d42e

File tree

3 files changed

+58
-67
lines changed

3 files changed

+58
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 53 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ case class ArrayContains(left: Expression, right: Expression)
299299
Examples:
300300
> SELECT _FUNC_(array(array(1, 2), array(3, 4));
301301
[1,2,3,4]
302-
""")
302+
""",
303+
since = "2.4.0")
303304
case class Flatten(child: Expression) extends UnaryExpression {
304305

305306
override def nullable: Boolean = child.nullable || dataType.containsNull
@@ -310,18 +311,14 @@ case class Flatten(child: Expression) extends UnaryExpression {
310311
.elementType.asInstanceOf[ArrayType]
311312
}
312313

313-
override def checkInputDataTypes(): TypeCheckResult = {
314-
if (
315-
ArrayType.acceptsType(child.dataType) &&
316-
ArrayType.acceptsType(child.dataType.asInstanceOf[ArrayType].elementType)
317-
) {
314+
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
315+
case ArrayType(_: ArrayType, _) =>
318316
TypeCheckResult.TypeCheckSuccess
319-
} else {
317+
case _ =>
320318
TypeCheckResult.TypeCheckFailure(
321319
s"The argument should be an array of arrays, " +
322320
s"but '${child.sql}' is of ${child.dataType.simpleString} type."
323321
)
324-
}
325322
}
326323

327324
override def nullSafeEval(array: Any): Any = {
@@ -339,8 +336,7 @@ case class Flatten(child: Expression) extends UnaryExpression {
339336

340337
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
341338
nullSafeCodeGen(ctx, ev, c => {
342-
val code =
343-
if (CodeGenerator.isPrimitiveType(dataType.elementType)) {
339+
val code = if (CodeGenerator.isPrimitiveType(dataType.elementType)) {
344340
genCodeForConcatOfPrimitiveElements(ctx, c, ev.value)
345341
} else {
346342
genCodeForConcatOfComplexElements(ctx, c, ev.value)
@@ -354,26 +350,25 @@ case class Flatten(child: Expression) extends UnaryExpression {
354350
childVariableName: String,
355351
coreLogic: String): String = {
356352
s"""
357-
|for(int z=0; z < $childVariableName.numElements(); z++) {
358-
| ${ev.isNull} |= $childVariableName.isNullAt(z);
359-
|}
360-
|if(!${ev.isNull}) {
361-
| $coreLogic
362-
|}
363-
""".stripMargin
353+
|for(int z=0; z < $childVariableName.numElements(); z++) {
354+
| ${ev.isNull} |= $childVariableName.isNullAt(z);
355+
|}
356+
|if(!${ev.isNull}) {
357+
| $coreLogic
358+
|}
359+
""".stripMargin
364360
}
365361

366362
private def genCodeForNumberOfElements(
367363
ctx: CodegenContext,
368364
childVariableName: String) : (String, String) = {
369365
val variableName = ctx.freshName("numElements")
370-
val code =
371-
s"""
372-
|int $variableName = 0;
373-
|for(int z=0; z < $childVariableName.numElements(); z++) {
374-
| $variableName += $childVariableName.getArray(z).numElements();
375-
|}
376-
""".stripMargin
366+
val code = s"""
367+
|int $variableName = 0;
368+
|for(int z=0; z < $childVariableName.numElements(); z++) {
369+
| $variableName += $childVariableName.getArray(z).numElements();
370+
|}
371+
""".stripMargin
377372
(code, variableName)
378373
}
379374

@@ -400,28 +395,28 @@ case class Flatten(child: Expression) extends UnaryExpression {
400395
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
401396

402397
s"""
403-
|$numElemCode
404-
|$unsafeArraySizeInBytes
405-
|byte[] $arrayName = new byte[$arraySizeName];
406-
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
407-
|Platform.putLong($arrayName, $baseOffset, $numElemName);
408-
|$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName);
409-
|int $counter = 0;
410-
|for(int k=0; k < $childVariableName.numElements(); k++) {
411-
| ArrayData arr = $childVariableName.getArray(k);
412-
| for(int l = 0; l < arr.numElements(); l++) {
413-
| if(arr.isNullAt(l)) {
414-
| $tempArrayDataName.setNullAt($counter);
415-
| } else {
416-
| $tempArrayDataName.set$primitiveValueTypeName(
417-
| $counter,
418-
| arr.get$primitiveValueTypeName(l)
419-
| );
420-
| }
421-
| $counter++;
422-
| }
423-
|}
424-
|$arrayDataName = $tempArrayDataName;
398+
|$numElemCode
399+
|$unsafeArraySizeInBytes
400+
|byte[] $arrayName = new byte[$arraySizeName];
401+
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
402+
|Platform.putLong($arrayName, $baseOffset, $numElemName);
403+
|$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName);
404+
|int $counter = 0;
405+
|for(int k=0; k < $childVariableName.numElements(); k++) {
406+
| ArrayData arr = $childVariableName.getArray(k);
407+
| for(int l = 0; l < arr.numElements(); l++) {
408+
| if(arr.isNullAt(l)) {
409+
| $tempArrayDataName.setNullAt($counter);
410+
| } else {
411+
| $tempArrayDataName.set$primitiveValueTypeName(
412+
| $counter,
413+
| arr.get$primitiveValueTypeName(l)
414+
| );
415+
| }
416+
| $counter++;
417+
| }
418+
|}
419+
|$arrayDataName = $tempArrayDataName;
425420
""".stripMargin
426421
}
427422

@@ -435,18 +430,18 @@ case class Flatten(child: Expression) extends UnaryExpression {
435430
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
436431

437432
s"""
438-
|$numElemCode
439-
|Object[] $arrayName = new Object[$numElemName];
440-
|int $counter = 0;
441-
|for(int k=0; k < $childVariableName.numElements(); k++) {
442-
| Object[] arr = $childVariableName.getArray(k).array();
443-
| for(int l = 0; l < arr.length; l++) {
444-
| $arrayName[$counter] = arr[l];
445-
| $counter++;
446-
| }
447-
|}
448-
|$arrayDataName = new $genericArrayClass($arrayName);
449-
""".stripMargin
433+
|$numElemCode
434+
|Object[] $arrayName = new Object[$numElemName];
435+
|int $counter = 0;
436+
|for(int k=0; k < $childVariableName.numElements(); k++) {
437+
| Object[] arr = $childVariableName.getArray(k).array();
438+
| for(int l = 0; l < arr.length; l++) {
439+
| $arrayName[$counter] = arr[l];
440+
| $counter++;
441+
| }
442+
|}
443+
|$arrayDataName = new $genericArrayClass($arrayName);
444+
""".stripMargin
450445
}
451446

452447
override def prettyName: String = "flatten"

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3306,7 +3306,7 @@ object functions {
33063306
* @group collection_funcs
33073307
* @since 2.4.0
33083308
*/
3309-
def flatten(e: Column): Column = withExpr{ Flatten(e.expr) }
3309+
def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
33103310

33113311
/**
33123312
* Returns an unordered array containing the keys of the map.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,15 +436,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
436436
Row(Seq.empty),
437437
Row(null),
438438
Row(null),
439-
Row(null)
440-
)
439+
Row(null))
441440

442441
checkAnswer(intDF.select(flatten($"i")), intDFResult)
443442
checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
444443
checkAnswer(
445444
oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
446-
Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))
447-
)
445+
Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
448446

449447
// Test cases with complex types
450448
val strDF = Seq(
@@ -468,15 +466,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
468466
Row(Seq.empty),
469467
Row(null),
470468
Row(null),
471-
Row(null)
472-
)
469+
Row(null))
473470

474471
checkAnswer(strDF.select(flatten($"s")), strDFResult)
475472
checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
476473
checkAnswer(
477474
oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
478-
Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))
479-
)
475+
Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
480476

481477
// Error test cases
482478
intercept[AnalysisException] {

0 commit comments

Comments
 (0)