Skip to content

Commit 305e77c

Browse files
zhichao-lirxin
authored andcommitted
[SPARK-8209][SQL] Add function conv
cc chenghao-intel adrian-wang Author: zhichao.li <[email protected]> Closes apache#6872 from zhichao-li/conv and squashes the following commits: 6ef3b37 [zhichao.li] add unittest and comments 78d9836 [zhichao.li] polish dataframe api and add unittest e2bace3 [zhichao.li] update to use ImplicitCastInputTypes cbcad3f [zhichao.li] add function conv
1 parent 59d24c2 commit 305e77c

File tree

5 files changed

+242
-2
lines changed

5 files changed

+242
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ object FunctionRegistry {
9999
expression[Ceil]("ceil"),
100100
expression[Ceil]("ceiling"),
101101
expression[Cos]("cos"),
102+
expression[Conv]("conv"),
102103
expression[EulerNumber]("e"),
103104
expression[Exp]("exp"),
104105
expression[Expm1]("expm1"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.{lang => jl}
21+
import java.util.Arrays
2122

2223
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2324
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
@@ -139,6 +140,196 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
139140

140141
// Hyperbolic cosine: hands `math.cosh` and the name "COSH" to UnaryMathExpression.
case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
141142

143+
/**
 * Convert a number, given as a string, from one base to another.
 *
 * The result is null when any input is null, when the number string is empty, or when
 * `fromBaseExpr` (or the absolute value of `toBaseExpr`) is outside
 * [Character.MIN_RADIX, Character.MAX_RADIX]. A positive target base yields an unsigned
 * result, a negative target base yields a signed one. Conversion stops at the first character
 * that is not a valid digit in the source base, and values that do not fit in 64 bits
 * saturate to the unsigned maximum.
 *
 * @param numExpr the number to be converted
 * @param fromBaseExpr from which base
 * @param toBaseExpr to which base
 */
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
  extends Expression with ImplicitCastInputTypes {

  override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable

  // conv() returns null for an empty string or an out-of-range base even when every child is
  // non-null, so the result has to be reported as nullable unconditionally.
  override def nullable: Boolean = true

  override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)

  /** The converted number is always rendered as a string. */
  override def dataType: DataType = StringType

  /** Returns the result of evaluating this expression on a given input Row */
  override def eval(input: InternalRow): Any = {
    val num = numExpr.eval(input)
    val fromBase = fromBaseExpr.eval(input)
    val toBase = toBaseExpr.eval(input)
    if (num == null || fromBase == null || toBase == null) {
      null
    } else {
      conv(num.asInstanceOf[UTF8String].getBytes,
        fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int])
    }
  }

  /**
   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
   * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
   * unsignedLongDiv(0, 5) == 0
   *
   * @param x is treated as unsigned
   * @param m is treated as signed
   */
  private def unsignedLongDiv(x: Long, m: Int): Long = {
    if (x >= 0) {
      x / m
    } else {
      // Let uval be the value of the unsigned long with the same bits as x
      // Two's complement => x = uval - 2*MAX - 2
      // => uval = x + 2*MAX + 2
      // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
      (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m)
    }
  }

  /**
   * Decode v into value[], right-aligned, one digit per byte; unused leading slots are zero.
   *
   * @param v is treated as an unsigned 64-bit integer
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param value scratch buffer that receives the digits; needs at least 64 slots
   */
  private def decode(v: Long, radix: Int, value: Array[Byte]): Unit = {
    var tmpV = v
    Arrays.fill(value, 0.asInstanceOf[Byte])
    var i = value.length - 1
    while (tmpV != 0) {
      val q = unsignedLongDiv(tmpV, radix)
      value(i) = (tmpV - q * radix).asInstanceOf[Byte]
      tmpV = q
      i -= 1
    }
  }

  /**
   * Convert value[] into a long. On overflow, return -1 (as MySQL does). If a
   * negative digit is found, ignore the suffix starting there.
   *
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first element that should be considered
   * @param value scratch buffer holding one digit per byte
   * @return the result should be treated as an unsigned 64-bit integer.
   */
  private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
    // Overflow becomes possible once v reaches this value.
    val bound = unsignedLongDiv(-1 - radix, radix)
    var v = 0L
    var overflow = false
    var i = fromPos
    while (!overflow && i < value.length && value(i) >= 0) {
      // v < 0 means the top bit is already set, so multiplying by radix (>= 2) cannot fit in
      // 64 bits. Otherwise both operands of the comparisons below are non-negative, which makes
      // the signed comparison equivalent to the intended unsigned one.
      if (v < 0 || (v >= bound && unsignedLongDiv(-1 - value(i), radix) < v)) {
        overflow = true
      } else {
        v = v * radix + value(i)
        i += 1
      }
    }
    if (overflow) -1L else v
  }

  /**
   * Convert the bytes in value[] to the corresponding chars.
   *
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first nonzero element
   * @param value scratch buffer, converted in place
   */
  private def byte2char(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
    var i = fromPos
    while (i < value.length) {
      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
      i += 1
    }
  }

  /**
   * Convert the chars in value[] to the corresponding integers. Convert invalid
   * characters to -1.
   *
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first nonzero element
   * @param value scratch buffer, converted in place
   */
  private def char2byte(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
    var i = fromPos
    while (i < value.length) {
      value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
      i += 1
    }
  }

  /**
   * Convert numbers between different number bases. If toBase>0 the result is
   * unsigned, otherwise it is signed.
   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
   *
   * @param n the bytes of the number string, optionally starting with '-'
   * @param fromBase base of `n`; must be between MIN_RADIX and MAX_RADIX
   * @param toBase target base; its absolute value must be between MIN_RADIX and MAX_RADIX
   * @return the converted number, or null for an empty input or an unsupported base
   */
  private def conv(n: Array[Byte], fromBase: Int, toBase: Int): UTF8String = {
    val basesSupported =
      fromBase >= Character.MIN_RADIX && fromBase <= Character.MAX_RADIX &&
        Math.abs(toBase) >= Character.MIN_RADIX && Math.abs(toBase) <= Character.MAX_RADIX

    if (n == null || n.isEmpty || !basesSupported) {
      null
    } else {
      var negative = n(0) == '-'
      val first = if (negative) 1 else 0

      // Per-call scratch buffer: large enough for the input digits or for the 64 digits a
      // 64-bit value can need in base 2, plus one spare slot for a leading '-'.
      val value = new Array[Byte](math.max(n.length, 64) + 1)

      // Copy the digits in the right side of the array
      val digitsStart = value.length - n.length + first
      System.arraycopy(n, first, value, digitsStart, n.length - first)
      char2byte(fromBase, digitsStart, value)

      // Do the conversion by going through a 64 bit integer
      var v = encode(fromBase, digitsStart, value)
      if (negative && toBase > 0) {
        v = if (v < 0) -1 else -v
      }
      if (toBase < 0 && v < 0) {
        v = -v
        negative = true
      }
      decode(v, Math.abs(toBase), value)

      // Find the first non-zero digit or the last digits if all are zero.
      val firstNonZero = value.indexWhere(_ != 0)
      val firstNonZeroPos = if (firstNonZero != -1) firstNonZero else value.length - 1

      byte2char(Math.abs(toBase), firstNonZeroPos, value)

      var resultStartPos = firstNonZeroPos
      if (negative && toBase < 0) {
        resultStartPos = firstNonZeroPos - 1
        value(resultStartPos) = '-'
      }
      UTF8String.fromBytes(Arrays.copyOfRange(value, resultStartPos, value.length))
    }
  }
}
332+
142333
// Exponential: hands `math.exp` and the name "EXP" to UnaryMathExpression.
case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
143334

144335
// exp(x) - 1: hands `math.expm1` and the name "EXPM1" to UnaryMathExpression.
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import scala.math.BigDecimal.RoundingMode
21-
2220
import com.google.common.math.LongMath
2321

2422
import org.apache.spark.SparkFunSuite
2523
import org.apache.spark.sql.catalyst.dsl.expressions._
2624
import org.apache.spark.sql.types._
2725

26+
2827
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
2928

3029
/**
@@ -95,6 +94,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
9594
checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
9695
}
9796

97+
test("conv") {
  // Each entry pairs a Conv expression with the value it is expected to evaluate to.
  val expectations: Seq[(Expression, Any)] = Seq(
    (Conv(Literal("3"), Literal(10), Literal(2)), "11"),
    (Conv(Literal("-15"), Literal(10), Literal(-16)), "-F"),
    (Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1"),
    (Conv(Literal("big"), Literal(36), Literal(16)), "3A48"),
    // A null number or a null base gives null.
    (Conv(Literal(null), Literal(36), Literal(16)), null),
    (Conv(Literal("3"), Literal(null), Literal(16)), null),
    // A target base of 37 and an empty number string both give null.
    (Conv(Literal("1234"), Literal(10), Literal(37)), null),
    (Conv(Literal(""), Literal(10), Literal(16)), null),
    (Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF"),
    // If there is an invalid digit in the number, the longest valid prefix should be converted.
    (Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
  )

  expectations.foreach { case (expression, expected) =>
    checkEvaluation(expression, expected)
  }
}
114+
98115
test("e") {
99116
testLeaf(EulerNumber, math.E)
100117
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ object functions {
6868
*/
6969
def column(colName: String): Column = Column(colName)
7070

71+
/**
 * Convert the number in the given column from base `fromBase` to base `toBase`.
 * Both bases are wrapped as literals and handed to the `Conv` expression together with
 * the column's expression.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def conv(num: Column, fromBase: Int, toBase: Int): Column = {
  val sourceBase = lit(fromBase).expr
  val targetBase = lit(toBase).expr
  Conv(num.expr, sourceBase, targetBase)
}
79+
80+
/**
 * Convert the number in the column named `numColName` from base `fromBase` to base `toBase`.
 * Looks the column up by name and delegates to the `Column`-based overload.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def conv(numColName: String, fromBase: Int, toBase: Int): Column = {
  val numColumn = Column(numColName)
  conv(numColumn, fromBase, toBase)
}
88+
7189
/**
7290
* Creates a [[Column]] of literal value.
7391
*

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ class MathExpressionsSuite extends QueryTest {
178178
Row(0.0, 1.0, 2.0))
179179
}
180180

181+
test("conv") {
  val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase")

  // Each entry pairs a query with the single row it is expected to return.
  val expectations = Seq(
    // DataFrame API: column by symbol, column by name, and literal inputs.
    df.select(conv('num, 10, 16)) -> Row("14D"),
    df.select(conv("num", 10, 16)) -> Row("14D"),
    df.select(conv(lit(100), 2, 16)) -> Row("4"),
    df.select(conv(lit(3122234455L), 10, 16)) -> Row("BA198457"),
    // SQL expressions: bases taken from columns, then from literals.
    df.selectExpr("conv(num, fromBase, toBase)") -> Row("101001101"),
    df.selectExpr("""conv("100", 2, 10)""") -> Row("4"),
    df.selectExpr("""conv("-10", 16, -10)""") -> Row("-16"),
    // for overflow
    df.selectExpr("""conv("9223372036854775807", 36, -16)""") -> Row("-1")
  )

  expectations.foreach { case (result, expectedRow) =>
    checkAnswer(result, expectedRow)
  }
}
193+
181194
test("floor") {
182195
testOneToOneMathFunction(floor, math.floor)
183196
}

0 commit comments

Comments
 (0)