Skip to content

Commit 5189690

Browse files
committed
[SPARK-8223][SPARK-8224] minor fix and style fix
1 parent 9434a28 commit 5189690

File tree

2 files changed

+10
-44
lines changed

2 files changed

+10
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,14 @@ case class Pow(left: Expression, right: Expression)
351351
}
352352
}
353353

354-
case class ShiftLeft(left: Expression, right: Expression) extends Expression {
355-
356-
override def nullable: Boolean = true
357-
358-
override def children: Seq[Expression] = Seq(left, right)
354+
case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
359355

360356
override def checkInputDataTypes(): TypeCheckResult = {
361357
(left.dataType, right.dataType) match {
362358
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
363359
case (_, IntegerType) => left.dataType match {
364-
case LongType | IntegerType | ShortType | ByteType => TypeCheckResult.TypeCheckSuccess
360+
case LongType | IntegerType | ShortType | ByteType =>
361+
return TypeCheckResult.TypeCheckSuccess
365362
case _ => // failed
366363
}
367364
case _ => // failed
@@ -399,37 +396,20 @@ case class ShiftLeft(left: Expression, right: Expression) extends Expression {
399396
}
400397

401398
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
402-
val eval1 = left.gen(ctx)
403-
val eval2 = right.gen(ctx)
404-
s"""
405-
${eval1.code}
406-
boolean ${ev.isNull} = ${eval1.isNull};
407-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
408-
if (!${ev.isNull}) {
409-
${eval2.code}
410-
if (!${eval2.isNull}) {
411-
${ev.primitive} = ${eval1.primitive} << ${eval2.primitive};
412-
} else {
413-
${ev.isNull} = true;
414-
}
415-
}
416-
"""
399+
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
417400
}
418401

419402
override def toString: String = s"ShiftLeft($left, $right)"
420403
}
421404

422-
case class ShiftRight(left: Expression, right: Expression) extends Expression {
423-
424-
override def nullable: Boolean = true
425-
426-
override def children: Seq[Expression] = Seq(left, right)
405+
case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
427406

428407
override def checkInputDataTypes(): TypeCheckResult = {
429408
(left.dataType, right.dataType) match {
430409
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
431410
case (_, IntegerType) => left.dataType match {
432-
case LongType | IntegerType | ShortType | ByteType => return TypeCheckResult.TypeCheckSuccess
411+
case LongType | IntegerType | ShortType | ByteType =>
412+
return TypeCheckResult.TypeCheckSuccess
433413
case _ => // failed
434414
}
435415
case _ => // failed
@@ -467,21 +447,7 @@ case class ShiftRight(left: Expression, right: Expression) extends Expression {
467447
}
468448

469449
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
470-
val eval1 = left.gen(ctx)
471-
val eval2 = right.gen(ctx)
472-
s"""
473-
${eval1.code}
474-
boolean ${ev.isNull} = ${eval1.isNull};
475-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
476-
if (!${ev.isNull}) {
477-
${eval2.code}
478-
if (!${eval2.isNull}) {
479-
${ev.primitive} = ${eval1.primitive} >> ${eval2.primitive};
480-
} else {
481-
${ev.isNull} = true;
482-
}
483-
}
484-
"""
450+
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
485451
}
486452

487453
override def toString: String = s"ShiftRight($left, $right)"

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,11 @@ class MathExpressionsSuite extends QueryTest {
248248
test("log1p") {
249249
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
250250
}
251-
251+
252252
test("shift left") {
253253
val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
254254
.toDF("a", "b", "c", "d", "e", "f")
255-
255+
256256
checkAnswer(
257257
df.select(
258258
shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),

0 commit comments

Comments
 (0)