Skip to content

Commit d85ae0b

Browse files
committed
add shiftrightunsigned
1 parent a59d14f commit d85ae0b

File tree

6 files changed

+111
-0
lines changed

6 files changed

+111
-0
lines changed

python/pyspark/sql/functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,17 @@ def shiftRight(col, numBits):
435435
jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
436436
return Column(jc)
437437

438+
@since(1.5)
def shiftRightUnsigned(col, numBits):
    """Unsigned shift the given value numBits right.

    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
    .collect()
    [Row(r=9223372036854775787)]
    """
    # Look up the JVM-side ``functions`` object through the active SparkContext and wrap the
    # returned Java column in a Python Column.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.shiftRightUnsigned(_to_java_column(col), numBits))
438449

439450
@since(1.4)
440451
def sparkPartitionId():

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ object FunctionRegistry {
129129
expression[Rint]("rint"),
130130
expression[ShiftLeft]("shiftleft"),
131131
expression[ShiftRight]("shiftright"),
132+
expression[ShiftRightUnsigned]("shiftrightunsigned"),
132133
expression[Signum]("sign"),
133134
expression[Signum]("signum"),
134135
expression[Sin]("sin"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
521521
}
522522
}
523523

524+
/**
 * Bitwise unsigned right shift (`>>>`) of `left` by `right` bits.
 *
 * The type check accepts a long, integer, short or byte value as the first argument and an
 * integer value as the second argument; a NullType operand on either side is accepted as well.
 * Evaluation yields null when either operand evaluates to null. The result type is LongType
 * for a long input and IntegerType for integer, short and byte inputs.
 */
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {

  override def checkInputDataTypes(): TypeCheckResult = {
    // Expressed as a single value instead of early `return`s with a fall-through failure.
    val accepted = (left.dataType, right.dataType) match {
      case (NullType, _) | (_, NullType) => true
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) => true
      case _ => false
    }
    if (accepted) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and " +
          s"an integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }
  }

  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    if (valueLeft == null) {
      // The right operand is not evaluated when the left one is already null.
      null
    } else {
      val valueRight = right.eval(input)
      if (valueRight == null) {
        null
      } else {
        val numBits = valueRight.asInstanceOf[Int]
        // Keep this match in result position (expected type Any) so that the Int-valued
        // branches are not numerically widened to Long.
        left.dataType match {
          case LongType => valueLeft.asInstanceOf[Long] >>> numBits
          case IntegerType => valueLeft.asInstanceOf[Int] >>> numBits
          case ShortType => valueLeft.asInstanceOf[Short] >>> numBits
          case ByteType => valueLeft.asInstanceOf[Byte] >>> numBits
        }
      }
    }
  }

  override def dataType: DataType = {
    left.dataType match {
      case LongType => LongType
      case IntegerType | ShortType | ByteType => IntegerType
      case _ => NullType
    }
  }

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Emits a `>>>` assignment through the null-safe code generation helper.
    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
  }
}
572+
524573
/**
525574
* Performs the inverse operation of HEX.
526575
* Resulting characters are returned as a byte array.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
264264
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
265265
}
266266

267+
test("shift right unsigned") {
  // A null operand on either side, or on both sides, evaluates to null.
  checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
  // This case previously constructed ShiftRight, so the expression under test was never
  // evaluated with two null operands.
  checkEvaluation(
    ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
    null)

  // Positive inputs: int, byte and short inputs give an Int result, a long input gives a Long.
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)

  // Negative long: a zero is shifted into the sign bit, so the result is a large positive value.
  checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}
279+
267280
test("hex") {
268281
checkEvaluation(Hex(Literal(28)), "1C")
269282
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,26 @@ object functions {
13431343
*/
13441344
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
13451345

1346+
/**
 * Performs an unsigned right shift of the given value by numBits. The result is a long value
 * when the given value is a long value; otherwise it is an integer value.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(columnName: String, numBits: Int): Column = {
  shiftRightUnsigned(Column(columnName), numBits)
}
1355+
1356+
/**
 * Performs an unsigned right shift of the given value by numBits. The result is a long value
 * when the given value is a long value; otherwise it is an integer value.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(e: Column, numBits: Int): Column = {
  val shiftAmount = lit(numBits).expr
  ShiftRightUnsigned(e.expr, shiftAmount)
}
1365+
13461366
/**
13471367
* Shift the given value numBits right. If the given value is a long value, it will return
13481368
* a long value else it will return an integer value.

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
304304
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
305305
}
306306

307+
test("shift right unsigned") {
  // "a" holds a negative long and "f" holds a null integer; "e" is not selected below.
  val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
    .toDF("a", "b", "c", "d", "e", "f")

  // Both query forms below are expected to produce this single row.
  val expected = Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)

  // Column-based API.
  checkAnswer(
    df.select(
      shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
      shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
    expected)

  // The same calls written as SQL expression strings.
  checkAnswer(
    df.selectExpr(
      "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
      "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
    expected)
}
323+
307324
test("binary log") {
308325
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
309326
checkAnswer(

0 commit comments

Comments
 (0)