From ac7fe9d85ba8594b1c3ed8d8b911b95e1670e468 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Wed, 1 Jul 2015 21:05:51 -0700 Subject: [PATCH 1/6] [SPARK-8223][SPARK-8224] right and left bit shift --- python/pyspark/sql/functions.py | 27 ++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/expressions/math.scala | 136 ++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 28 +++- .../org/apache/spark/sql/functions.scala | 38 +++++ .../spark/sql/MathExpressionsSuite.scala | 34 +++++ 6 files changed, 264 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e2be88e9e3b9..fa60c5638d561 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -412,6 +412,33 @@ def sha2(col, numBits): return Column(jc) +@since(1.5) +def shiftLeft(col, numBits): + """Shift the the given value numBits left. Returns int for tinyint, smallint and int and + bigint for bigint a. + + >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + [Row(r=42)] + + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftRight(col, numBits): + """Shift the the given value numBits right. Returns int for tinyint, smallint and int and + bigint for bigint a. + + >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + [Row(r=21)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d53eaedda56b0..f1ee5d28186a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -125,6 +125,8 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[ShiftLeft]("shiftleft"), + expression[ShiftRight]("shiftright"), expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index da63f2fa970cf..1097fea776e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -351,6 +351,142 @@ case class Pow(left: Expression, right: Expression) } } +case class ShiftLeft(left: Expression, right: Expression) extends Expression { + + override def nullable: Boolean = true + + override def children: Seq[Expression] = Seq(left, right) + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftLeft expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l << valueRight.asInstanceOf[Integer] + case i: Integer => i << valueRight.asInstanceOf[Integer] + case s: Short => s << valueRight.asInstanceOf[Integer] + case b: Byte => b << valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + ${ev.primitive} = ${eval1.primitive} << ${eval2.primitive}; + } else { + ${ev.isNull} = true; + } + } + """ + } + + override def toString: String = s"ShiftLeft($left, $right)" +} + +case class ShiftRight(left: Expression, right: Expression) extends Expression { + + override def nullable: Boolean = true + + override def children: Seq[Expression] = Seq(left, right) + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => return TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftRight expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l >> valueRight.asInstanceOf[Integer] + case i: Integer => i >> valueRight.asInstanceOf[Integer] + case s: Short => s >> valueRight.asInstanceOf[Integer] + case b: Byte => b >> valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + ${ev.primitive} = ${eval1.primitive} >> ${eval2.primitive}; + } else { + ${ev.isNull} = true; + } + } + """ + } + + override def toString: String = s"ShiftRight($left, $right)" +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b932d4ab850c7..5fc1830db1aa7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType} +import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("shift left") { + checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42) + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + + checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) + } + + test("shift right") { + checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21) + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + + checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) + } + test("hex") { checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5767668dd339b..22158ca39c320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1280,6 +1280,44 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Shift the the given value numBits left. Returns int for tinyint, smallint and int and + * bigint for bigint a. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Integer): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits left. Returns int for tinyint, smallint and int and + * bigint for bigint a. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(columnName: String, numBits: Integer): Column = + shiftLeft(Column(columnName), numBits) + + /** + * Bitwise right shift of the given value. Returns int for tinyint, smallint and int and + * bigint for bigint a. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Integer): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits right. Returns int for tinyint, smallint and int and + * bigint for bigint a. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(columnName: String, numBits: Integer): Column = + shiftRight(Column(columnName), numBits) + /** * Computes the signum of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index d6331aa4ff09e..b4907bb295418 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -248,6 +248,40 @@ class MathExpressionsSuite extends QueryTest { test("log1p") { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('e, null), shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", + "shiftLeft(e, null)", "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('e, null), shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(e, null)", "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null, null)) + } test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") From 44ee324f4a3494bba9d98eb0754291a794617032 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Wed, 1 Jul 2015 21:10:15 -0700 Subject: [PATCH 2/6] [SPARK-8223][SPARK-8224] docu fix --- python/pyspark/sql/functions.py | 1 - sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fa60c5638d561..f4c5ec5d6d131 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -419,7 +419,6 @@ def shiftLeft(col, numBits): >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] - """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 22158ca39c320..19274cf544a11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1300,7 +1300,7 @@ object functions { shiftLeft(Column(columnName), numBits) /** - * Bitwise right shift of the given value. Returns int for tinyint, smallint and int and + * Shift the the given value numBits right. Returns int for tinyint, smallint and int and * bigint for bigint a. * * @group math_funcs From 5189690adab45244dd917d45a8aab429a30b96df Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Wed, 1 Jul 2015 21:45:30 -0700 Subject: [PATCH 3/6] [SPARK-8223][SPARK-8224] minor fix and style fix --- .../spark/sql/catalyst/expressions/math.scala | 50 +++---------------- .../spark/sql/MathExpressionsSuite.scala | 4 +- 2 files changed, 10 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5725405451435..e32d949620c29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -351,17 +351,14 @@ case class Pow(left: Expression, right: Expression) } } -case class ShiftLeft(left: Expression, right: Expression) extends Expression { - - override def nullable: Boolean = true - - override def children: Seq[Expression] = Seq(left, right) +case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression { override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => TypeCheckResult.TypeCheckSuccess + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess case _ => // failed } case _ => // failed @@ -399,37 +396,20 @@ case class ShiftLeft(left: Expression, right: Expression) extends Expression { } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - ${ev.primitive} = ${eval1.primitive} << ${eval2.primitive}; - } else { - ${ev.isNull} = true; - } - } - """ + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") } override def toString: String = s"ShiftLeft($left, $right)" } -case class ShiftRight(left: Expression, right: Expression) extends Expression { - - override def nullable: Boolean = true - - override def children: Seq[Expression] = Seq(left, right) +case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression { override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => return TypeCheckResult.TypeCheckSuccess + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess case _ => // failed } case _ => // failed @@ -467,21 +447,7 @@ case class ShiftRight(left: Expression, right: Expression) extends Expression { } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - ${ev.primitive} = ${eval1.primitive} >> ${eval2.primitive}; - } else { - ${ev.isNull} = true; - } - } - """ + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") } override def toString: String = s"ShiftRight($left, $right)" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b4907bb295418..ebe7919f7e30b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -248,11 +248,11 @@ class MathExpressionsSuite extends QueryTest { test("log1p") { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } - + test("shift left") { val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) .toDF("a", "b", "c", "d", "e", "f") - + checkAnswer( df.select( shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), From f62870697215357c96525311d8281c4832526cd7 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Wed, 1 Jul 2015 23:48:47 -0700 Subject: [PATCH 4/6] [SPARK-8223][SPARK-8224] removed toString; updated function description --- python/pyspark/sql/functions.py | 6 ++---- .../spark/sql/catalyst/expressions/math.scala | 4 ---- .../scala/org/apache/spark/sql/functions.scala | 16 ++++++++-------- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1fa24095dbaaf..bccde6083ca3c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -414,8 +414,7 @@ def sha2(col, numBits): @since(1.5) def shiftLeft(col, numBits): - """Shift the the given value numBits left. Returns int for tinyint, smallint and int and - bigint for bigint a. + """Shift the the given value numBits left. >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] @@ -427,8 +426,7 @@ def shiftLeft(col, numBits): @since(1.5) def shiftRight(col, numBits): - """Shift the the given value numBits right. Returns int for tinyint, smallint and int and - bigint for bigint a. + """Shift the the given value numBits right. >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() [Row(r=21)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 49ec80edbeef8..7504c6a066657 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -398,8 +398,6 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") } - - override def toString: String = s"ShiftLeft($left, $right)" } case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression { @@ -449,8 +447,6 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") } - - override def toString: String = s"ShiftRight($left, $right)" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7298a108c358b..6555340d57e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1299,8 +1299,8 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** - * Shift the the given value numBits left. Returns int for tinyint, smallint and int and - * bigint for bigint a. + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 @@ -1308,8 +1308,8 @@ object functions { def shiftLeft(e: Column, numBits: Integer): Column = ShiftLeft(e.expr, lit(numBits).expr) /** - * Shift the the given value numBits left. Returns int for tinyint, smallint and int and - * bigint for bigint a. + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 @@ -1318,8 +1318,8 @@ object functions { shiftLeft(Column(columnName), numBits) /** - * Shift the the given value numBits right. Returns int for tinyint, smallint and int and - * bigint for bigint a. + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 @@ -1327,8 +1327,8 @@ object functions { def shiftRight(e: Column, numBits: Integer): Column = ShiftRight(e.expr, lit(numBits).expr) /** - * Shift the the given value numBits right. Returns int for tinyint, smallint and int and - * bigint for bigint a. + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 From f3f64e6677d5f6f7576b5a619ce081d1c8f0da07 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 2 Jul 2015 00:02:15 -0700 Subject: [PATCH 5/6] [SPARK-8223][SPARK-8224] Integer -> Int --- .../src/main/scala/org/apache/spark/sql/functions.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6555340d57e1c..a5b68286853ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1305,7 +1305,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftLeft(e: Column, numBits: Integer): Column = ShiftLeft(e.expr, lit(numBits).expr) + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) /** * Shift the the given value numBits left. If the given value is a long value, this function @@ -1314,7 +1314,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftLeft(columnName: String, numBits: Integer): Column = + def shiftLeft(columnName: String, numBits: Int): Column = shiftLeft(Column(columnName), numBits) /** @@ -1324,7 +1324,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRight(e: Column, numBits: Integer): Column = ShiftRight(e.expr, lit(numBits).expr) + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) /** * Shift the the given value numBits right. If the given value is a long value, it will return @@ -1333,7 +1333,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRight(columnName: String, numBits: Integer): Column = + def shiftRight(columnName: String, numBits: Int): Column = shiftRight(Column(columnName), numBits) /** From 8023bb558da36c6bc7cd2fbe9b5240297a69679c Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 2 Jul 2015 01:18:52 -0700 Subject: [PATCH 6/6] [SPARK-8223][SPARK-8224] fixed test --- .../apache/spark/sql/MathExpressionsSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 6f4ca5dfc64ed..4c5696deaff81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -266,14 +266,14 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( df.select( shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), - shiftLeft('e, null), shiftLeft('f, 1)), - Row(42.toLong, 42, 42.toShort, 42.toByte, null, null)) + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) checkAnswer( df.selectExpr( "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", - "shiftLeft(e, null)", "shiftLeft(f, 1)"), - Row(42.toLong, 42, 42.toShort, 42.toByte, null, null)) + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) } test("shift right") { @@ -283,14 +283,14 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( df.select( shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), - shiftRight('e, null), shiftRight('f, 1)), - Row(21.toLong, 21, 21.toShort, 21.toByte, null, null)) + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) checkAnswer( df.selectExpr( "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", - "shiftRight(e, null)", "shiftRight(f, 1)"), - Row(21.toLong, 21, 21.toShort, 21.toByte, null, null)) + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) } test("binary log") {