diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f9a15d4a66309..bccde6083ca3c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -412,6 +412,30 @@ def sha2(col, numBits):
     return Column(jc)
 
 
+@since(1.5)
+def shiftLeft(col, numBits):
+    """Shift the given value numBits left.
+
+    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
+    [Row(r=42)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
+    return Column(jc)
+
+
+@since(1.5)
+def shiftRight(col, numBits):
+    """Shift the given value numBits right.
+
+    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
+    [Row(r=21)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6f04298d4711b..aa051b163363a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -125,6 +125,8 @@ object FunctionRegistry {
     expression[Pow]("power"),
     expression[UnaryPositive]("positive"),
     expression[Rint]("rint"),
+    expression[ShiftLeft]("shiftleft"),
+    expression[ShiftRight]("shiftright"),
     expression[Signum]("sign"),
     expression[Signum]("signum"),
     expression[Sin]("sin"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 8633eb06ffee4..7504c6a066657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -351,6 +351,104 @@ case class Pow(left: Expression, right: Expression)
   }
 }
 
+case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+      case (_, IntegerType) => left.dataType match {
+        case LongType | IntegerType | ShortType | ByteType =>
+          return TypeCheckResult.TypeCheckSuccess
+        case _ => // failed
+      }
+      case _ => // failed
+    }
+    TypeCheckResult.TypeCheckFailure(
+      s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
+      s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val valueLeft = left.eval(input)
+    if (valueLeft != null) {
+      val valueRight = right.eval(input)
+      if (valueRight != null) {
+        valueLeft match {
+          case l: Long => l << valueRight.asInstanceOf[Integer]
+          case i: Integer => i << valueRight.asInstanceOf[Integer]
+          case s: Short => s << valueRight.asInstanceOf[Integer]
+          case b: Byte => b << valueRight.asInstanceOf[Integer]
+        }
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+
+  override def dataType: DataType = {
+    left.dataType match {
+      case LongType => LongType
+      case IntegerType | ShortType | ByteType => IntegerType
+      case _ => NullType
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+  }
+}
+
+case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+      case (_, IntegerType) => left.dataType match {
+        case LongType | IntegerType | ShortType | ByteType =>
+          return TypeCheckResult.TypeCheckSuccess
+        case _ => // failed
+      }
+      case _ => // failed
+    }
+    TypeCheckResult.TypeCheckFailure(
+      s"ShiftRight expects long, integer, short or byte value as first argument and an " +
+      s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val valueLeft = left.eval(input)
+    if (valueLeft != null) {
+      val valueRight = right.eval(input)
+      if (valueRight != null) {
+        valueLeft match {
+          case l: Long => l >> valueRight.asInstanceOf[Integer]
+          case i: Integer => i >> valueRight.asInstanceOf[Integer]
+          case s: Short => s >> valueRight.asInstanceOf[Integer]
+          case b: Byte => b >> valueRight.asInstanceOf[Integer]
+        }
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+
+  override def dataType: DataType = {
+    left.dataType match {
+      case LongType => LongType
+      case IntegerType | ShortType | ByteType => IntegerType
+      case _ => NullType
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+  }
+}
+
 /**
  * Performs the inverse operation of HEX.
  * Resulting characters are returned as a byte array.
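Both eval and the generated code above delegate to the JVM's built-in shift operators; >> is the arithmetic (sign-preserving) right shift, which is why the test suites below expect -42L >> 1 to be -21L. A minimal standalone sketch of that behaviour, not part of the patch (plain Scala, no Spark dependencies; the object name is illustrative only):

    // Plain JVM shift semantics that eval() and the generated code rely on
    // (standalone sketch; names are illustrative).
    object ShiftSemanticsSketch extends App {
      println(21 << 1)    // 42  -- shifting left by one bit doubles the value
      println(42L >> 1)   // 21  -- '>>' is the arithmetic right shift
      println(-42L >> 1)  // -21 -- the sign bit is preserved, unlike '>>>'
      // Short and Byte operands are promoted to Int before shifting, which is
      // why ShiftLeft/ShiftRight report IntegerType for ShortType/ByteType input.
      val fromShort: Int = 21.toShort << 1
      println(fromShort)  // 42, typed as Int
    }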
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index b3345d7069159..aa27fe3cd5564 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType}
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
   }
 
+  test("shift left") {
+    checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+
+    checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+  }
+
+  test("shift right") {
+    checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+
+    checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+  }
+
   test("hex") {
     checkEvaluation(Hex(Literal(28)), "1C")
     checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e6f623bdf39eb..a5b68286853ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1298,6 +1298,44 @@ object functions {
    */
   def rint(columnName: String): Column = rint(Column(columnName))
 
+  /**
+   * Shift the given value numBits left. If the given value is a long value, this function
+   * will return a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr)
+
+  /**
+   * Shift the given value numBits left. If the given value is a long value, this function
+   * will return a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftLeft(columnName: String, numBits: Int): Column =
+    shiftLeft(Column(columnName), numBits)
+
+  /**
+   * Shift the given value numBits right. If the given value is a long value, it will return
+   * a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
+
+  /**
+   * Shift the given value numBits right. If the given value is a long value, it will return
+   * a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRight(columnName: String, numBits: Int): Column =
+    shiftRight(Column(columnName), numBits)
+
   /**
    * Computes the signum of the given value.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c03cde38d75d0..4c5696deaff81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneNonNegativeMathFunction(log1p, math.log1p)
   }
 
+  test("shift left") {
+    val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
+      .toDF("a", "b", "c", "d", "e", "f")
+
+    checkAnswer(
+      df.select(
+        shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
+        shiftLeft('f, 1)),
+      Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+
+    checkAnswer(
+      df.selectExpr(
+        "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(c, 1)", "shiftLeft(d, 1)",
+        "shiftLeft(f, 1)"),
+      Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+  }
+
+  test("shift right") {
+    val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
+      .toDF("a", "b", "c", "d", "e", "f")
+
+    checkAnswer(
+      df.select(
+        shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
+        shiftRight('f, 1)),
+      Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+
+    checkAnswer(
+      df.selectExpr(
+        "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
+        "shiftRight(f, 1)"),
+      Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+  }
+
   test("binary log") {
     val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
     checkAnswer(
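As a usage illustration of the new DataFrame API added in sql/core/src/main/scala/org/apache/spark/sql/functions.scala, a minimal sketch, not part of the patch; it assumes a Spark 1.5-era SQLContext named sqlContext with its implicits imported, and the column names are illustrative:

    // Usage sketch only; assumes `sqlContext` and its implicits are in scope.
    import sqlContext.implicits._
    import org.apache.spark.sql.functions.{shiftLeft, shiftRight}

    val df = Seq((21L, 42)).toDF("a", "b")
    // The Long column keeps LongType, the Int column keeps IntegerType.
    df.select(shiftLeft($"a", 1), shiftRight($"b", 1)).collect()
    // => Array(Row(42L, 21))

    // The expressions are also registered in the FunctionRegistry, so the
    // same is available through SQL / selectExpr:
    df.selectExpr("shiftleft(a, 1)", "shiftright(b, 1)").collect()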