Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,52 @@ public boolean toByte(IntWrapper intWrapper) {
return false;
}

/**
 * Strict variant of {@code toLong(LongWrapper)}: converts this UTF8String (trimmed if needed)
 * to a long and fails loudly instead of reporting failure through a boolean. Intended for
 * casts performed while ANSI mode is enabled.
 *
 * @return the parsed long value.
 * @throws NumberFormatException if {@code toLong} reports that this string does not hold a
 *         valid numeric value.
 */
public long toLongExact() {
  LongWrapper parsed = new LongWrapper();
  // Guard clause: reject first, so the success path is the plain fall-through.
  if (!toLong(parsed)) {
    throw new NumberFormatException("invalid input syntax for type numeric: " + this);
  }
  return parsed.value;
}

/**
 * Strict variant of {@code toInt(IntWrapper)}: converts this UTF8String (trimmed if needed)
 * to an int and fails loudly instead of reporting failure through a boolean. Intended for
 * casts performed while ANSI mode is enabled.
 *
 * @return the parsed int value.
 * @throws NumberFormatException if {@code toInt} reports that this string does not hold a
 *         valid numeric value.
 */
public int toIntExact() {
  IntWrapper parsed = new IntWrapper();
  // Guard clause: reject first, so the success path is the plain fall-through.
  if (!toInt(parsed)) {
    throw new NumberFormatException("invalid input syntax for type numeric: " + this);
  }
  return parsed.value;
}

/**
 * Converts this UTF8String to a short by first parsing it with {@link #toIntExact()} and then
 * checking that the int survives narrowing to short unchanged.
 *
 * @return the parsed value as a short.
 * @throws NumberFormatException if {@code toIntExact()} rejects the string, or if the parsed
 *         int does not round-trip through a cast to short.
 */
public short toShortExact() {
  int parsedInt = this.toIntExact();
  short narrowed = (short) parsedInt;
  // The cast silently drops high-order bits; a mismatch means the value did not fit.
  if (narrowed != parsedInt) {
    throw new NumberFormatException("invalid input syntax for type numeric: " + this);
  }
  return narrowed;
}

/**
 * Converts this UTF8String to a byte by first parsing it with {@link #toIntExact()} and then
 * checking that the int survives narrowing to byte unchanged.
 *
 * @return the parsed value as a byte.
 * @throws NumberFormatException if {@code toIntExact()} rejects the string, or if the parsed
 *         int does not round-trip through a cast to byte.
 */
public byte toByteExact() {
  int parsedInt = this.toIntExact();
  byte narrowed = (byte) parsedInt;
  // The cast silently drops high-order bits; a mismatch means the value did not fit.
  if (narrowed != parsedInt) {
    throw new NumberFormatException("invalid input syntax for type numeric: " + this);
  }
  return narrowed;
}

@Override
public String toString() {
return new String(getBytes(), StandardCharsets.UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toLongExact())
case StringType =>
val result = new LongWrapper()
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
Expand All @@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toIntExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
Expand All @@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toShortExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toShort(result)) {
Expand Down Expand Up @@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
buildCast[UTF8String](_, _.toByteExact())
case StringType =>
val result = new IntWrapper()
buildCast[UTF8String](_, s => if (s.toByte(result)) {
Expand Down Expand Up @@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// Please refer to https://github.com/apache/spark/pull/26640
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
} catch {
case _: NumberFormatException => null
case _: NumberFormatException =>
if (ansiEnabled) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
null
}
})
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
Expand Down Expand Up @@ -664,7 +677,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(ansiEnabled && d == null) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
d
}
}
})
case BooleanType =>
Expand All @@ -684,7 +702,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(floatStr, true)
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (ansiEnabled && f == null) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
f
}
}
})
case BooleanType =>
Expand Down Expand Up @@ -1128,12 +1151,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
from match {
case StringType =>
(c, evPrim, evNull) =>
// Generated-Java statement to run when BigDecimal parsing fails: throw under ANSI mode,
// otherwise mark the result as null.
// `c` is interpolated into the template as a Java expression (see `$c.toString()` below), so
// it has to be concatenated in the generated code; placing it inside the Java string literal
// would report the expression text rather than the offending input value.
val handleException = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull =true;"
}
code"""
try {
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
$handleException
}
"""
case BooleanType =>
Expand Down Expand Up @@ -1355,6 +1383,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}

private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
(c, evPrim, evNull) => code"$evPrim = $c.toByteExact();"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1386,6 +1416,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
private[this] def castToShortCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
(c, evPrim, evNull) => code"$evPrim = $c.toShortExact();"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1415,6 +1447,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}

private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
(c, evPrim, evNull) => code"$evPrim = $c.toIntExact();"
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
(c, evPrim, evNull) =>
Expand Down Expand Up @@ -1443,9 +1477,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}

private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType if ansiEnabled =>
(c, evPrim, evNull) => code"$evPrim = $c.toLongExact();"
case StringType =>
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])

(c, evPrim, evNull) =>
code"""
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
Expand Down Expand Up @@ -1476,14 +1511,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
// Generated-Java statement to run when the string is neither a parsable float nor a special
// floating-point literal: throw under ANSI mode, otherwise mark the result as null.
// `c` is interpolated into the template as a Java expression (see `$c.toString()` below), so
// it has to be concatenated in the generated code; placing it inside the Java string literal
// would report the expression text rather than the offending input value.
val handleNull = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull = true;"
}
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
$evNull = true;
$handleNull
} else {
$evPrim = f.floatValue();
}
Expand All @@ -1507,14 +1547,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
// Generated-Java statement to run when the string is neither a parsable double nor a special
// floating-point literal: throw under ANSI mode, otherwise mark the result as null.
// `c` is interpolated into the template as a Java expression (see `$c.toString()` below), so
// it has to be concatenated in the generated code; placing it inside the Java string literal
// would report the expression text rather than the offending input value.
val handleNull = if (ansiEnabled) {
  s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
  s"$evNull = true;"
}
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
$evNull = true;
$handleNull
} else {
$evPrim = d.doubleValue();
}
Expand Down
Loading