13 changes: 13 additions & 0 deletions python/pyspark/sql/functions.py
@@ -436,6 +436,19 @@ def shiftRight(col, numBits):
return Column(jc)


@since(1.5)
def shiftRightUnsigned(col, numBits):
"""Unsigned shift the the given value numBits right.

>>> df = sqlContext.createDataFrame([(-42,)], ['a'])
>>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
[Row(r=9223372036854775787)]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
return Column(jc)
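
A minimal sketch in plain Python (illustrative only, independent of Spark) of the arithmetic behind the doctest result above: treating -42 as an unsigned 64-bit value and then shifting right by one bit yields the same number shiftRightUnsigned returns for a LongType column.

>>> ((-42) & 0xFFFFFFFFFFFFFFFF) >> 1
9223372036854775787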


@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.

@@ -129,6 +129,7 @@ object FunctionRegistry {
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
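
Registering the expression under the lowercase name "shiftrightunsigned" is what makes it reachable from SQL and from DataFrame.selectExpr. A minimal usage sketch (assuming an active sqlContext, as in the Python doctest above):

sqlContext.sql("SELECT shiftrightunsigned(-42, 1)").collect()
// the returned value should be 2147483627, since the Int literal -42 is shifted as unsigned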

@@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}

case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
case (_, IntegerType) => left.dataType match {
case LongType | IntegerType | ShortType | ByteType =>
return TypeCheckResult.TypeCheckSuccess
case _ => // failed
}
case _ => // failed
}
TypeCheckResult.TypeCheckFailure(
s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
}

override def eval(input: InternalRow): Any = {
val valueLeft = left.eval(input)
if (valueLeft != null) {
val valueRight = right.eval(input)
if (valueRight != null) {
valueLeft match {
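// Byte and Short operands are promoted to Int by `>>>`, so those branches produce an Int,
// matching the IntegerType this expression declares for them in dataType below.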
case l: Long => l >>> valueRight.asInstanceOf[Integer]
case i: Integer => i >>> valueRight.asInstanceOf[Integer]
case s: Short => s >>> valueRight.asInstanceOf[Integer]
case b: Byte => b >>> valueRight.asInstanceOf[Integer]
}
} else {
null
}
} else {
null
}
}

override def dataType: DataType = {
left.dataType match {
case LongType => LongType
case IntegerType | ShortType | ByteType => IntegerType
case _ => NullType
}
}

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
}
}
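
A quick illustrative sketch (plain Scala, not specific to this patch) of the JVM shift semantics that eval and dataType above rely on: `>>>` on a Long is a 64-bit logical shift, while Byte and Short operands are promoted to Int before shifting.

-42L >>> 1          // 9223372036854775787L
(-42: Byte) >>> 1   // 2147483627 (promoted to Int, so the result is an Int)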

/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.

@@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}

test("shift right unsigned") {
checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
checkEvaluation(
ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)

checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}

test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1343,6 +1343,26 @@ object functions {
*/
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)

/**
* Unsigned shift the given value numBits right. If the given value is a long value,
* it will return a long value else it will return an integer value.
*
* @group math_funcs
* @since 1.5.0
*/
def shiftRightUnsigned(columnName: String, numBits: Int): Column =
shiftRightUnsigned(Column(columnName), numBits)

/**
* Unsigned shift the given value numBits right. If the given value is a long value,
* it will return a long value else it will return an integer value.
*
* @group math_funcs
* @since 1.5.0
*/
def shiftRightUnsigned(e: Column, numBits: Int): Column =
ShiftRightUnsigned(e.expr, lit(numBits).expr)
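
A brief usage sketch for the two overloads above (assuming org.apache.spark.sql.functions._ is imported and df is a DataFrame with a LongType column "a" holding -42L; both calls should yield 9223372036854775787):

df.select(shiftRightUnsigned("a", 1), shiftRightUnsigned(df("a"), 1))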

/**
* Shift the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.

@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}

test("shift right unsigned") {
val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
.toDF("a", "b", "c", "d", "e", "f")

checkAnswer(
df.select(
shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))

checkAnswer(
df.selectExpr(
"shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
"shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
}

test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(