|
18 | 18 | package org.apache.spark.sql.catalyst.expressions |
19 | 19 |
|
20 | 20 | import java.{lang => jl} |
| 21 | +import java.util.Arrays |
21 | 22 |
|
22 | 23 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult |
23 | 24 | import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} |
@@ -139,6 +140,196 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") |
139 | 140 |
|
/**
 * Unary math expression defined by handing `math.cosh` and the name "COSH" to
 * UnaryMathExpression; how the function is applied to `child` is decided by that base class,
 * which is not visible here.
 */
140 | 141 | case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
141 | 142 |
|
/**
 * Convert a number, given as a string of digits, from one base to another.
 *
 * The digits of `numExpr` are read in base `fromBaseExpr`, accumulated in an unsigned 64-bit
 * integer, and written back as upper-case digits in base `abs(toBaseExpr)`. A positive target
 * base yields an unsigned result; a negative target base yields a signed result. The result is
 * null when any input is null, when the number is the empty string, or when either base lies
 * outside [Character.MIN_RADIX, Character.MAX_RADIX].
 *
 * @param numExpr the number to be converted
 * @param fromBaseExpr from which base
 * @param toBaseExpr to which base
 */
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
  extends Expression with ImplicitCastInputTypes {

  override def foldable: Boolean =
    numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable

  // conv() yields null for an empty number or an out-of-range base even when no input is null,
  // so the result has to be declared nullable regardless of the children.
  override def nullable: Boolean = true

  override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)

  /**
   * Returns the [[DataType]] of the result of evaluating this expression. It is
   * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
   */
  override def dataType: DataType = StringType

  /**
   * Returns the result of evaluating this expression on a given input Row: null if any child
   * evaluates to null, otherwise the converted number as a UTF8String (or null, see `conv`).
   */
  override def eval(input: InternalRow): Any = {
    val num = numExpr.eval(input)
    val fromBase = fromBaseExpr.eval(input)
    val toBase = toBaseExpr.eval(input)
    if (num == null || fromBase == null || toBase == null) {
      null
    } else {
      conv(num.asInstanceOf[UTF8String].getBytes,
        fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int])
    }
  }

  /**
   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
   * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
   * unsignedLongDiv(0, 5) == 0
   *
   * @param x is treated as unsigned
   * @param m is treated as signed
   */
  private def unsignedLongDiv(x: Long, m: Int): Long = {
    if (x >= 0) {
      x / m
    } else {
      // Let uval be the value of the unsigned long with the same bits as x
      // Two's complement => x = uval - 2*MAX - 2
      // => uval = x + 2*MAX + 2
      // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
      x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
    }
  }

  /**
   * Decode v into `digits`, one digit value per byte, right-aligned and zero-padded on the left.
   *
   * @param digits scratch buffer that receives the digits; must hold at least 64 bytes
   * @param v is treated as an unsigned 64-bit integer
   * @param radix must be between MIN_RADIX and MAX_RADIX
   */
  private def decode(digits: Array[Byte], v: Long, radix: Int): Unit = {
    Arrays.fill(digits, 0.toByte)
    var remaining = v
    var i = digits.length - 1
    while (remaining != 0) {
      val q = unsignedLongDiv(remaining, radix)
      // remaining - q * radix is the unsigned remainder, always in [0, radix).
      digits(i) = (remaining - q * radix).toByte
      remaining = q
      i -= 1
    }
  }

  /**
   * Convert the digit values in `digits` into a long. On overflow, return -1 (as MySQL does).
   * If a negative digit is found, ignore the suffix starting there.
   *
   * @param digits buffer holding one digit value per byte
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first element that should be considered
   * @return the result should be treated as an unsigned 64-bit integer.
   */
  private def encode(digits: Array[Byte], radix: Int, fromPos: Int): Long = {
    var v: Long = 0L
    // Overflow becomes possible once v exceeds this value.
    val bound = unsignedLongDiv(-1 - radix, radix)
    var i = fromPos
    while (i < digits.length && digits(i) >= 0) {
      if (v >= bound) {
        // Check for overflow
        if (unsignedLongDiv(-1 - digits(i), radix) < v) {
          return -1L
        }
      }
      v = v * radix + digits(i)
      i += 1
    }
    v
  }

  /**
   * Convert the digit values in `digits` to the corresponding upper-case chars, in place.
   *
   * @param digits buffer holding one digit value per byte
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first element that should be converted
   */
  private def byte2char(digits: Array[Byte], radix: Int, fromPos: Int): Unit = {
    var i = fromPos
    while (i < digits.length) {
      digits(i) = Character.toUpperCase(Character.forDigit(digits(i), radix)).toByte
      i += 1
    }
  }

  /**
   * Convert the chars in `digits` to the corresponding digit values, in place. Invalid
   * characters are converted to -1.
   *
   * @param digits buffer holding one character per byte
   * @param radix must be between MIN_RADIX and MAX_RADIX
   * @param fromPos is the first element that should be converted
   */
  private def char2byte(digits: Array[Byte], radix: Int, fromPos: Int): Unit = {
    var i = fromPos
    while (i < digits.length) {
      digits(i) = Character.digit(digits(i), radix).toByte
      i += 1
    }
  }

  /**
   * Convert numbers between different number bases. If toBase>0 the result is
   * unsigned, otherwise it is signed.
   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
   *
   * @param n the bytes of the number to convert, optionally starting with '-'
   * @param fromBase the base `n` is written in
   * @param toBase the base to convert to; a negative value requests a signed result
   * @return the converted number, or null if `n` is null or empty or a base is out of range
   */
  private def conv(n: Array[Byte], fromBase: Int, toBase: Int): UTF8String = {
    val toRadix = math.abs(toBase)
    val basesValid =
      fromBase >= Character.MIN_RADIX && fromBase <= Character.MAX_RADIX &&
        toRadix >= Character.MIN_RADIX && toRadix <= Character.MAX_RADIX

    if (n == null || n.isEmpty || !basesValid) {
      null
    } else {
      val negativeInput = n(0) == '-'
      val first = if (negativeInput) 1 else 0
      val digitCount = n.length - first

      // The buffer is local to the call and must fit both the input digits (of any length) and
      // the longest possible output: 64 binary digits plus a leading '-'.
      val minBufferLength = 65
      val buf = new Array[Byte](math.max(digitCount, minBufferLength))

      // Copy the digits in the right side of the array
      val inputStart = buf.length - digitCount
      System.arraycopy(n, first, buf, inputStart, digitCount)
      char2byte(buf, fromBase, inputStart)

      // Do the conversion by going through a 64 bit integer
      var v = encode(buf, fromBase, inputStart)
      var negative = negativeInput
      if (negative && toBase > 0) {
        v = if (v < 0) -1L else -v
      }
      if (toBase < 0 && v < 0) {
        v = -v
        negative = true
      }
      decode(buf, v, toRadix)

      // Find the first non-zero digit or the last digit if all are zero.
      val firstNonZeroPos = {
        val firstNonZero = buf.indexWhere(_ != 0)
        if (firstNonZero != -1) firstNonZero else buf.length - 1
      }

      byte2char(buf, toRadix, firstNonZeroPos)

      var resultStartPos = firstNonZeroPos
      if (negative && toBase < 0) {
        // Always in bounds: decode() emits at most 64 digits into a buffer of at least 65 bytes.
        resultStartPos = firstNonZeroPos - 1
        buf(resultStartPos) = '-'.toByte
      }
      UTF8String.fromBytes(Arrays.copyOfRange(buf, resultStartPos, buf.length))
    }
  }
}
| 332 | + |
/**
 * Unary math expression defined by handing `math.exp` and the name "EXP" to
 * UnaryMathExpression; how the function is applied to `child` is decided by that base class,
 * which is not visible here.
 */
142 | 333 | case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
143 | 334 |
|
/**
 * Unary math expression defined by handing `math.expm1` and the name "EXPM1" to
 * UnaryMathExpression; how the function is applied to `child` is decided by that base class,
 * which is not visible here.
 */
144 | 335 | case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
|
0 commit comments