From c852d46a9f133d785ff488ab938da3b269c2bb73 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 14:10:28 +0800 Subject: [PATCH 1/8] Add function unhex --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 38 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 8 ++++ .../org/apache/spark/sql/functions.scala | 18 +++++++++ .../spark/sql/MathExpressionsSuite.scala | 10 +++++ 5 files changed, 75 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b17457d3094c2..3e1f405f34cb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -156,6 +156,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[Upper]("ucase"), + expression[UnHex]("unhex"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a022f3727bd58..725a60de64ef6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -354,6 +354,44 @@ case class Pow(left: Expression, right: Expression) } } +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. 
+ */ +case class UnHex(child: Expression) + extends UnaryExpression with ExpectsInputTypes with Serializable { + + override def expectedChildTypes: Seq[DataType] = Seq(StringType) + + override def dataType: DataType = BinaryType + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].toString) + } + } + + private def unhex(s: String): Array[Byte] = { + // append a leading 0 if needed + val str = if (s.length % 2 == 1) {"0" + s} else {s} + val result = new Array[Byte](str.length / 2) + var i = 0 + while (i < str.length()) { + try { + result(i / 2) = Integer.parseInt(str.substring(i, i + 2), 16).asInstanceOf[Byte] + } catch { + // invalid character present, return null + case _: NumberFormatException => return null + } + i += 2 + } + result + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b932d4ab850c7..d0b46c1791573 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -238,6 +238,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("unhex") { + checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(UnHex(Literal("E4B889E9878DE79A84")), "三重的".getBytes) + // scalastyle:on + } + test("hypot") { testBinary(Hypot, math.hypot) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4d9a019058228..0352bb75d01b9 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1062,6 +1062,24 @@ object functions { */ def hex(colName: String): Column = hex(Column(colName)) + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(column: Column): Column = UnHex(column.expr) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(colName: String): Column = unhex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index d6331aa4ff09e..1dfe5235debf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -225,6 +225,16 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) } + test("unhex") { + val data = Seq(("1C", "737472696E67")).toDF("a", "b") + checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } From 11945c768b9b7de9b44e5cafeef66f0e8f702f52 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 14:42:01 +0800 Subject: [PATCH 2/8] style --- 
.../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 725a60de64ef6..679958c99efcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -373,7 +373,7 @@ case class UnHex(child: Expression) unhex(num.asInstanceOf[UTF8String].toString) } } - + private def unhex(s: String): Array[Byte] = { // append a leading 0 if needed val str = if (s.length % 2 == 1) {"0" + s} else {s} From cde73f57600bc83f89c9614bfdc690bf5e970895 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 14:57:16 +0800 Subject: [PATCH 3/8] update to use AutoCastInputTypes --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 679958c99efcc..8572a941aa69c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -359,7 +359,7 @@ case class Pow(left: Expression, right: Expression) * Resulting characters are returned as a byte array. 
*/ case class UnHex(child: Expression) - extends UnaryExpression with ExpectsInputTypes with Serializable { + extends UnaryExpression with AutoCastInputTypes with Serializable { override def expectedChildTypes: Seq[DataType] = Seq(StringType) From bffd37f27c37087c64a0dd201d217e623b6ab7ee Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 19:42:30 -0700 Subject: [PATCH 4/8] change to use Hex in apache common package --- .../spark/sql/catalyst/expressions/math.scala | 23 +++++++------------ .../expressions/MathFunctionsSuite.scala | 5 +--- .../spark/sql/MathExpressionsSuite.scala | 1 - 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 8572a941aa69c..8fdd37f80e17c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} +import java.nio.charset.{StandardCharsets, Charset} import java.util.Arrays +import org.apache.commons.codec.DecoderException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -370,25 +372,16 @@ case class UnHex(child: Expression) if (num == null) { null } else { - unhex(num.asInstanceOf[UTF8String].toString) + unhex(num.asInstanceOf[UTF8String]) } } - private def unhex(s: String): Array[Byte] = { - // append a leading 0 if needed - val str = if (s.length % 2 == 1) {"0" + s} else {s} - val result = new Array[Byte](str.length / 2) - var i = 0 - while (i < str.length()) { - try { - result(i / 2) = Integer.parseInt(str.substring(i, i + 2), 16).asInstanceOf[Byte] - } catch { - // invalid character present, return null - case _: 
NumberFormatException => return null - } - i += 2 + private def unhex(utf8Str: UTF8String): Array[Byte] = { + try { + new org.apache.commons.codec.binary.Hex(StandardCharsets.UTF_8).decode(utf8Str.getBytes) + } catch { + case _: DecoderException => null } - result } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index d0b46c1791573..9213cdb3e515f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -240,10 +240,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("unhex") { checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - // scalastyle:off - // Turn off scala style for non-ascii chars - checkEvaluation(UnHex(Literal("E4B889E9878DE79A84")), "三重的".getBytes) - // scalastyle:on + checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) } test("hypot") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 1dfe5235debf9..991b6c104d666 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -232,7 +232,6 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) - } test("hypot") { From 607d7a3a3aa51795965df5bd97f1603c3f3d668a Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Tue, 30 Jun 2015 23:45:48 -0700 Subject: [PATCH 5/8] use checkInputTypes --- .../spark/sql/catalyst/expressions/math.scala | 36 
+++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 8fdd37f80e17c..80d30aa2a2e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} -import java.nio.charset.{StandardCharsets, Charset} import java.util.Arrays -import org.apache.commons.codec.DecoderException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -361,27 +359,43 @@ case class Pow(left: Expression, right: Expression) * Resulting characters are returned as a byte array. */ case class UnHex(child: Expression) - extends UnaryExpression with AutoCastInputTypes with Serializable { - - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + extends UnaryExpression with Serializable { override def dataType: DataType = BinaryType + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$unHex accepts String type, not ${child.dataType}") + } + } + + override def eval(input: InternalRow): Any = { val num = child.eval(input) if (num == null) { null } else { - unhex(num.asInstanceOf[UTF8String]) + unhex(num.asInstanceOf[UTF8String].toString) } } - private def unhex(utf8Str: UTF8String): Array[Byte] = { - try { - new org.apache.commons.codec.binary.Hex(StandardCharsets.UTF_8).decode(utf8Str.getBytes) - } catch { - case _: DecoderException => null + private def unhex(s: String): Array[Byte] = { + // append a leading 
0 if needed + val str = if (s.length % 2 == 1) {"0" + s} else {s} + val result = new Array[Byte](str.length / 2) + var i = 0 + while (i < str.length()) { + try { + result(i / 2) = Integer.parseInt(str.substring(i, i + 2), 16).asInstanceOf[Byte] + } catch { + // invalid character present, return null + case _: NumberFormatException => return null + } + i += 2 } + result } } From fe5c14a7ad299a4729af88af490c7e8395f996df Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 1 Jul 2015 05:58:45 -0700 Subject: [PATCH 6/8] add todigit --- .../spark/sql/catalyst/expressions/math.scala | 44 +++++++++++++------ .../expressions/MathFunctionsSuite.scala | 1 + 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 80d30aa2a2e92..7b6f1a44d2ad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -358,8 +358,7 @@ case class Pow(left: Expression, right: Expression) * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. 
*/ -case class UnHex(child: Expression) - extends UnaryExpression with Serializable { +case class UnHex(child: Expression) extends UnaryExpression with Serializable { override def dataType: DataType = BinaryType @@ -367,35 +366,52 @@ case class UnHex(child: Expression) if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$unHex accepts String type, not ${child.dataType}") + TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") } } - override def eval(input: InternalRow): Any = { val num = child.eval(input) if (num == null) { null } else { - unhex(num.asInstanceOf[UTF8String].toString) + unhex(num.asInstanceOf[UTF8String].getBytes) } } - private def unhex(s: String): Array[Byte] = { - // append a leading 0 if needed - val str = if (s.length % 2 == 1) {"0" + s} else {s} - val result = new Array[Byte](str.length / 2) + private val hexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private def toDigit(b: Byte): Byte = { + val digit = hexDigits(b) + if (digit == -1) { + throw new NumberFormatException(s"invalid hex number $b") + } + digit + } + + private def unhex(inputBytes: Array[Byte]): Array[Byte] = { + var bytes = inputBytes + if ((bytes.length & 0x01) != 0) { + bytes = 48.toByte +: bytes // padding with '0' + } + val out = new Array[Byte](bytes.length >> 1) + // two characters form the hex value. 
var i = 0 - while (i < str.length()) { + while (i < bytes.length) { try { - result(i / 2) = Integer.parseInt(str.substring(i, i + 2), 16).asInstanceOf[Byte] + out(i / 2) = ((toDigit(bytes(i)) << 4) | toDigit(bytes(i + 1)) & 0xFF).toByte + i += 2 } catch { - // invalid character present, return null case _: NumberFormatException => return null } - i += 2 } - result + out } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 9213cdb3e515f..b3345d7069159 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -241,6 +241,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("unhex") { checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) + checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) } test("hypot") { From a4ae6dc9bcf21ef51a3aab9b0b3a362a9a0a2d2b Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 1 Jul 2015 21:11:12 +0800 Subject: [PATCH 7/8] add udf_unhex to whitelist --- .../apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f88e62763ca70..415a81644c58f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -949,6 +949,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { 
"udf_trim", "udf_ucase", "udf_unix_timestamp", + "udf_unhex", "udf_upper", "udf_var_pop", "udf_var_samp", From 379356e42f9a10e2ae5bdd302794de121d3a958e Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 2 Jul 2015 09:23:01 +0800 Subject: [PATCH 8/8] remove exception checking --- .../spark/sql/catalyst/expressions/math.scala | 21 ++++++------------- .../spark/sql/MathExpressionsSuite.scala | 1 + 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7b6f1a44d2ad8..66491f4c875e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -379,7 +379,7 @@ case class UnHex(child: Expression) extends UnaryExpression with Serializable { } } - private val hexDigits = { + private val unhexDigits = { val array = Array.fill[Byte](128)(-1) (0 to 9).foreach(i => array('0' + i) = i.toByte) (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) @@ -387,29 +387,20 @@ case class UnHex(child: Expression) extends UnaryExpression with Serializable { array } - private def toDigit(b: Byte): Byte = { - val digit = hexDigits(b) - if (digit == -1) { - throw new NumberFormatException(s"invalid hex number $b") - } - digit - } - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { var bytes = inputBytes if ((bytes.length & 0x01) != 0) { - bytes = 48.toByte +: bytes // padding with '0' + bytes = '0'.toByte +: bytes } val out = new Array[Byte](bytes.length >> 1) // two characters form the hex value. 
var i = 0 while (i < bytes.length) { - try { - out(i / 2) = ((toDigit(bytes(i)) << 4) | toDigit(bytes(i + 1)) & 0xFF).toByte + val first = if (bytes(i) >= 0) unhexDigits(bytes(i)) else -1 + val second = if (bytes(i + 1) >= 0) unhexDigits(bytes(i + 1)) else -1 + if (first == -1 || second == -1) { return null} // -1: not a hex digit or non-ASCII byte + out(i / 2) = (((first << 4) | second) & 0xFF).toByte i += 2 - } catch { - case _: NumberFormatException => return null - } } out } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 991b6c104d666..c03cde38d75d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -232,6 +232,7 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) } test("hypot") {