From 3ed7795b40c6e5c81dacd7a69d23c0372fbd253c Mon Sep 17 00:00:00 2001 From: root1 Date: Wed, 18 Dec 2019 11:34:10 +0530 Subject: [PATCH 01/15] Cast to Decimal --- .../catalyst/analysis/ANSISQLStandard.scala | 45 +++++++ .../sql/catalyst/analysis/Analyzer.scala | 1 + .../ANSISQL/AnsiSqlCastToDecimal.scala | 111 ++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 12 +- 4 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala new file mode 100644 index 0000000000000..c8cb68de7e4eb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.ANSISQL.AnsiSqlCastToDecimal +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType + +object ANSISQLStandard { + val ANSIStandardCastRules: Seq[Rule[LogicalPlan]] = Seq(ANSICast) + + object ANSICast extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.ansiEnabled) { + plan.transformExpressions { + case Cast(child, dataType, timeZoneId) + if child.dataType != DecimalType => + dataType match { + case _: DecimalType => AnsiSqlCastToDecimal(child, timeZoneId) + } + } + } else { + plan + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2438ef9218224..868d445edb5ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,6 +235,7 @@ class Analyzer( Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("Remove Unresolved Hints", Once, new ResolveHints.RemoveAllHints(conf)), + Batch("Ansi Standard", Once, ANSISQLStandard.ANSIStandardCastRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala new file mode 100644 index 0000000000000..a4804d1e826ed --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.ANSISQL + +import java.math.{BigDecimal => JavaBigDecimal} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case class AnsiSqlCastToDecimal(child: Expression, timeZoneId: Option[String]) + extends CastBase { + + override def dataType: DataType = DecimalType.defaultConcreteType + + override def toString: String = s"AnsiSqlCastToDecimal($child as ${dataType.simpleString})" + + override def nullable: Boolean = child.nullable + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case StringType | LongType | IntegerType | NullType | FloatType | ShortType | + DoubleType | ByteType | TimestampType => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to decimal") + } + + override def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + }) + case t: IntegralType => + super.castToDecimal(from, target) + + case TimestampType => + super.castToDecimal(from, target) + + case x: FractionalType => + b => try { + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"invalid input syntax for type numeric: $x") + } + } + + override def castToDecimalCode(from: DataType, target: DecimalType, + ctx: CodegenContext): CastFunction = { + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) + from match { + case StringType => + (c, evPrim, evNull) => + code""" + try { + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} + } catch (java.lang.NumberFormatException e) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + """ + case t: IntegralType => + super.castToDecimalCode(from, target, ctx) + + case TimestampType => + super.castToDecimalCode(from, target, ctx) + + case x: FractionalType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + code""" + try { + Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} + } catch (java.lang.NumberFormatException e) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + """ + } + } + + override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" + + override protected def ansiEnabled: Boolean = true +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 
fa27a48419dbb..40e74b4510f67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -276,7 +276,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) // [[func]] assumes the input is no longer null because eval already does the null check. - @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) + @inline protected[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) private lazy val dateFormatter = DateFormatter(zoneId) private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) @@ -606,7 +606,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit * * NOTE: this modifies `value` in-place, so don't call it on external data. */ - private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + protected[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { if (value.changePrecision(decimalType.precision, decimalType.scale)) { value } else { @@ -629,7 +629,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled) - private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + protected[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { // According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`. @@ -804,7 +804,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. 
- private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block + protected[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, @@ -1093,7 +1093,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + protected[this] def changePrecision(d: ExprValue, decimalType: DecimalType, evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { if (canNullSafeCast) { code""" @@ -1119,7 +1119,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - private[this] def castToDecimalCode( + protected[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { From 1686c6a0d2720606ba46eb7006d2e1177739734c Mon Sep 17 00:00:00 2001 From: root1 Date: Thu, 19 Dec 2019 17:47:00 +0530 Subject: [PATCH 02/15] Fix Review Comments --- .../catalyst/analysis/ANSISQLStandard.scala | 11 +- .../ANSISQL/AnsiSqlCastToDecimal.scala | 111 -------- .../spark/sql/catalyst/expressions/Cast.scala | 254 +++++++++++++++++- 3 files changed, 245 insertions(+), 131 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala index c8cb68de7e4eb..9b6cb1c1004bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.ANSISQL.AnsiSqlCastToDecimal -import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DecimalType object ANSISQLStandard { val ANSIStandardCastRules: Seq[Rule[LogicalPlan]] = Seq(ANSICast) @@ -31,11 +29,8 @@ object ANSISQLStandard { override def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.ansiEnabled) { plan.transformExpressions { - case Cast(child, dataType, timeZoneId) - if child.dataType != DecimalType => - dataType match { - case _: DecimalType => AnsiSqlCastToDecimal(child, timeZoneId) - } + case Cast(child, dataType, timeZoneId) => + AnsiCast(child, dataType, timeZoneId) } } else { plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala deleted file mode 100644 index a4804d1e826ed..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ANSISQL/AnsiSqlCastToDecimal.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.ANSISQL - -import java.math.{BigDecimal => JavaBigDecimal} - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, TimeZoneAwareExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -case class AnsiSqlCastToDecimal(child: Expression, timeZoneId: Option[String]) - extends CastBase { - - override def dataType: DataType = DecimalType.defaultConcreteType - - override def toString: String = s"AnsiSqlCastToDecimal($child as ${dataType.simpleString})" - - override def nullable: Boolean = child.nullable - - override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = - copy(timeZoneId = Option(timeZoneId)) - - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case StringType | LongType | IntegerType | NullType | FloatType | ShortType | - DoubleType | ByteType | TimestampType => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to decimal") - } - - override def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { - case StringType => - buildCast[UTF8String](_, s => try { - changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) - } catch { - case _: NumberFormatException => - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - }) - case t: IntegralType => - super.castToDecimal(from, target) - - case TimestampType => - super.castToDecimal(from, target) - - case x: FractionalType => - b => try { - changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) - } catch { - case _: NumberFormatException => - throw new AnalysisException(s"invalid input syntax for type numeric: $x") - } - } - - override def castToDecimalCode(from: DataType, target: DecimalType, - ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) - val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) - from match { - case StringType => - (c, evPrim, evNull) => - code""" - try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} - } catch (java.lang.NumberFormatException e) { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - """ - case t: IntegralType => - super.castToDecimalCode(from, target, ctx) - - case TimestampType => - super.castToDecimalCode(from, target, ctx) - - case x: FractionalType => - // All other numeric types can be represented precisely as Doubles - (c, evPrim, evNull) => - code""" - try { - Decimal $tmp 
= Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} - } catch (java.lang.NumberFormatException e) { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - """ - } - } - - override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})" - - override protected def ansiEnabled: Boolean = true -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 40e74b4510f67..7f99642d3b6bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import java.util.Locale import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -481,7 +482,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // LongConverter - private[this] def castToLong(from: DataType): Any => Any = from match { + protected[this] def castToLong(from: DataType): Any => Any = from match { case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) @@ -498,7 +499,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // IntConverter - private[this] def castToInt(from: DataType): Any => Any = from match { + protected[this] def castToInt(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) @@ -517,7 +518,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // ShortConverter - private[this] def castToShort(from: DataType): Any => Any = from match { + protected[this] def castToShort(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toShort(result)) { @@ -558,7 +559,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // ByteConverter - private[this] def castToByte(from: DataType): Any => Any = from match { + protected[this] def castToByte(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toByte(result)) { @@ -658,7 +659,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // DoubleConverter - private[this] def castToDouble(from: DataType): Any => Any = from match { + protected[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => { val doubleStr = s.toString @@ -678,7 +679,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // FloatConverter - private[this] def castToFloat(from: DataType): Any => Any = from match { + protected[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => { val floatStr = s.toString @@ -1354,7 +1355,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } - private[this] def castToByteCode(from: DataType, ctx: 
CodegenContext): CastFunction = from match { + protected[this] def castToByteCode( + from: DataType, + ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => @@ -1383,7 +1386,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } - private[this] def castToShortCode( + protected[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => @@ -1414,7 +1417,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (short) $c;" } - private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + protected[this] def castToIntCode( + from: DataType, + ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => @@ -1442,7 +1447,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (int) $c;" } - private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + protected[this] def castToLongCode( + from: DataType, + ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) @@ -1471,7 +1478,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (long) $c;" } - private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { + protected[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case StringType => val floatStr = ctx.freshVariable("floatStr", StringType) @@ -1502,7 +1509,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { + protected[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case StringType => val doubleStr = ctx.freshVariable("doubleStr", StringType) @@ -1689,6 +1696,229 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St copy(timeZoneId = Option(timeZoneId)) override protected val ansiEnabled: Boolean = true + + override def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + }) + case _ => super.castToDecimal(from, target) + } + + override def castToDouble(from: DataType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => { + val doubleStr = s.toString + try doubleStr.toDouble catch { + case _: NumberFormatException => + val d = Cast.processFloatingPointSpecialLiterals(doubleStr, true) + if(d == null) { + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + } else { + d.asInstanceOf[Double].doubleValue() + } + } + }) + case _ => super.castToDouble(from) + } + + override def castToFloat(from: DataType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => 
{ + val floatStr = s.toString + try floatStr.toFloat catch { + case _: NumberFormatException => + val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) + if (f == null) { + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + } else { + f.asInstanceOf[Float].floatValue() + } + } + }) + case _ => + super.castToFloat(from) + } + + override def castToLong(from: DataType): Any => Any = from match { + case StringType => + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value + else throw new AnalysisException(s"invalid input syntax for type numeric: $s")) + case _ => + super.castToLong(from) + } + + override def castToInt(from: DataType): Any => Any = from match { + case StringType => + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value + else throw new AnalysisException(s"invalid input syntax for type numeric: $s")) + case _ => + super.castToInt(from) + } + + override def castToShort(from: DataType): Any => Any = from match { + case StringType => + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toShort(result)) { + result.value.toShort + } else { + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + }) + case _ => + super.castToShort(from) + } + + override def castToByte(from: DataType): Any => Any = from match { + case StringType => + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toByte(result)) { + result.value.toByte + } else { + throw new AnalysisException(s"invalid input syntax for type numeric: $s") + }) + case _ => + super.castToByte(from) + } + + override def castToDecimalCode( + from: DataType, + target: DecimalType, + ctx: CodegenContext): CastFunction = { + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) + from match { + case StringType => + (c, evPrim, evNull) => + code""" + try { + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} + } catch (java.lang.NumberFormatException e) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + """ + case _ => super.castToDecimalCode(from, target, ctx) + } + } + + override def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + case StringType => + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) + (c, evPrim, evNull) => + code""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); + if ($c.toByte($wrapper)) { + $evPrim = (byte) $wrapper.value; + } else { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + $wrapper = null; + """ + case _ => + super.castToByteCode(from, ctx) + } + + override def castToShortCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + case StringType => + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) + (c, evPrim, evNull) => + code""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); + if ($c.toShort($wrapper)) { + $evPrim = (short) $wrapper.value; + } else { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + $wrapper = null; + """ + case _ => + super.castToShortCode(from, ctx) + } + + override def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + case StringType => + val wrapper = ctx.freshVariable("intWrapper", 
classOf[UTF8String.IntWrapper]) + (c, evPrim, evNull) => + code""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + $wrapper = null; + """ + case _ => + super.castToIntCode(from, ctx) + } + + override def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { + case StringType => + val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) + (c, evPrim, evNull) => + code""" + UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); + if ($c.toLong($wrapper)) { + $evPrim = $wrapper.value; + } else { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } + $wrapper = null; + """ + case _ => + super.castToLongCode(from, ctx) + } + + override def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { + from match { + case StringType => + val floatStr = ctx.freshVariable("floatStr", StringType) + (c, evPrim, evNull) => + code""" + final String $floatStr = $c.toString(); + try { + $evPrim = Float.valueOf($floatStr); + } catch (java.lang.NumberFormatException e) { + final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); + if (f == null) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } else { + $evPrim = f.floatValue(); + } + } + """ + case _ => + super.castToFloatCode(from, ctx) + } + } + + override def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { + from match { + case StringType => + val doubleStr = ctx.freshVariable("doubleStr", StringType) + (c, evPrim, evNull) => + code""" + final String $doubleStr = $c.toString(); + try { + $evPrim = Double.valueOf($doubleStr); + } catch (java.lang.NumberFormatException e) { + final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); + if (d == null) { + throw new AnalysisException("invalid input syntax for type numeric: $c") + } else { + $evPrim = d.doubleValue(); + } + } + """ + case _ => + super.castToDoubleCode(from, ctx) + } + } } /** From 69ee2310d88cabe427d071df32d545865d882313 Mon Sep 17 00:00:00 2001 From: root1 Date: Fri, 27 Dec 2019 19:13:29 +0530 Subject: [PATCH 03/15] Fix --- .../sql/catalyst/analysis/Analyzer.scala | 1 - .../spark/sql/catalyst/expressions/Cast.scala | 384 ++++++------------ .../sql/catalyst/expressions/CastSuite.scala | 221 ++++++---- 3 files changed, 278 insertions(+), 328 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 868d445edb5ab..2438ef9218224 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,7 +235,6 @@ class Analyzer( Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("Remove Unresolved Hints", Once, new ResolveHints.RemoveAllHints(conf)), - Batch("Ansi Standard", Once, ANSISQLStandard.ANSIStandardCastRules: _*), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7f99642d3b6bb..7d771a235f3cb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -277,7 +277,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) // [[func]] assumes the input is no longer null because eval already does the null check. - @inline protected[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) + @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) private lazy val dateFormatter = DateFormatter(zoneId) private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) @@ -482,7 +482,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // LongConverter - protected[this] def castToLong(from: DataType): Any => Any = from match { + private[this] def castToLong(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value + else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")) case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) @@ -499,7 +503,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // IntConverter - protected[this] def castToInt(from: DataType): Any => Any = from match { + private[this] def castToInt(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value + else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")) case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) @@ -518,13 +526,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // ShortConverter - protected[this] def castToShort(from: DataType): Any => Any = from match { + private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort } else { - null + if (ansiEnabled) { + throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -559,13 +571,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // ByteConverter - protected[this] def castToByte(from: DataType): Any => Any = from match { + private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte } else { - null + if (ansiEnabled) { + throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -607,7 +623,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit * * NOTE: this modifies `value` in-place, so don't call it on external data. 
*/ - protected[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { if (value.changePrecision(decimalType.precision, decimalType.scale)) { value } else { @@ -630,14 +646,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled) - protected[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { + private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { // According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`. // Please refer to https://github.com/apache/spark/pull/26640 changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target) } catch { - case _: NumberFormatException => null + case _: NumberFormatException => + if (ansiEnabled) { + throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + } else { + null + } }) case BooleanType => buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) @@ -659,7 +680,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // DoubleConverter - protected[this] def castToDouble(from: DataType): Any => Any = from match { + private[this] def castToDouble(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + buildCast[UTF8String](_, s => { + val doubleStr = s.toString + try doubleStr.toDouble catch { + case _: NumberFormatException => + val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) + if(d == null) { + throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + } else { + d.asInstanceOf[Double].doubleValue() + } + } + }) case StringType => buildCast[UTF8String](_, s => { val doubleStr = s.toString @@ -679,7 +713,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } // FloatConverter - protected[this] def castToFloat(from: DataType): Any => Any = from match { + private[this] def castToFloat(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + buildCast[UTF8String](_, s => { + val floatStr = s.toString + try floatStr.toFloat catch { + case _: NumberFormatException => + val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) + if (f == null) { + throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + } else { + f.asInstanceOf[Float].floatValue() + } + } + }) case StringType => buildCast[UTF8String](_, s => { val floatStr = s.toString @@ -805,7 +852,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. 
- protected[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, @@ -1094,7 +1141,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - protected[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { if (canNullSafeCast) { code""" @@ -1120,7 +1167,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - protected[this] def castToDecimalCode( + private[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { @@ -1131,10 +1178,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code""" try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim())); + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { - $evNull = true; + if ($ansiEnabled) { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } else { + $evNull =true; + } } """ case BooleanType => @@ -1355,7 +1406,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } - protected[this] def castToByteCode( + private[this] def castToByteCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => @@ -1366,10 +1417,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; } else { - $evNull = true; + if ($ansiEnabled) { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } else { + $evNull = true; + } } $wrapper = null; - """ + """ case BooleanType => (c, evPrim, evNull) => code"$evPrim = $c ? 
(byte) 1 : (byte) 0;" case DateType => @@ -1386,7 +1441,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } - protected[this] def castToShortCode( + private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => @@ -1397,7 +1452,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; } else { - $evNull = true; + if ($ansiEnabled) { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } else { + $evNull = true; + } } $wrapper = null; """ @@ -1417,9 +1476,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (short) $c;" } - protected[this] def castToIntCode( + private[this] def castToIntCode( from: DataType, ctx: CodegenContext): CastFunction = from match { + case StringType if ansiEnabled => + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) + (c, evPrim, evNull) => + code""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } + $wrapper = null; + """ case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => @@ -1447,7 +1518,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (int) $c;" } - protected[this] def castToLongCode( + private[this] def castToLongCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => @@ -1459,7 +1530,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; } else { - $evNull = true; + if ($ansiEnabled) { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } else { + $evNull = true; + } } $wrapper = null; """ @@ -1478,8 +1553,24 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (long) $c;" } - protected[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { + private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { + case StringType if ansiEnabled => + val floatStr = ctx.freshVariable("floatStr", StringType) + (c, evPrim, evNull) => + code""" + final String $floatStr = $c.toString(); + try { + $evPrim = Float.valueOf($floatStr); + } catch (java.lang.NumberFormatException e) { + final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); + if (f == null) { + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + } else { + $evPrim = f.floatValue(); + } + } + """ case StringType => val floatStr = ctx.freshVariable("floatStr", StringType) (c, evPrim, evNull) => @@ -1509,8 +1600,24 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - protected[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { + private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { + case StringType if ansiEnabled => + val doubleStr = ctx.freshVariable("doubleStr", StringType) + (c, evPrim, 
evNull) => + code""" + final String $doubleStr = $c.toString(); + try { + $evPrim = Double.valueOf($doubleStr); + } catch (java.lang.NumberFormatException e) { + final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); + if (d == null) { + throw new AnalysisException("invalid input syntax for type numeric: $c"); + } else { + $evPrim = d.doubleValue(); + } + } + """ case StringType => val doubleStr = ctx.freshVariable("doubleStr", StringType) (c, evPrim, evNull) => @@ -1696,229 +1803,6 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St copy(timeZoneId = Option(timeZoneId)) override protected val ansiEnabled: Boolean = true - - override def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { - case StringType => - buildCast[UTF8String](_, s => try { - changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) - } catch { - case _: NumberFormatException => - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - }) - case _ => super.castToDecimal(from, target) - } - - override def castToDouble(from: DataType): Any => Any = from match { - case StringType => - buildCast[UTF8String](_, s => { - val doubleStr = s.toString - try doubleStr.toDouble catch { - case _: NumberFormatException => - val d = Cast.processFloatingPointSpecialLiterals(doubleStr, true) - if(d == null) { - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - } else { - d.asInstanceOf[Double].doubleValue() - } - } - }) - case _ => super.castToDouble(from) - } - - override def castToFloat(from: DataType): Any => Any = from match { - case StringType => - buildCast[UTF8String](_, s => { - val floatStr = s.toString - try floatStr.toFloat catch { - case _: NumberFormatException => - val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) - if (f == null) { - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - } else { - f.asInstanceOf[Float].floatValue() - } - } - }) - case _ => - super.castToFloat(from) - } - - override def castToLong(from: DataType): Any => Any = from match { - case StringType => - val result = new LongWrapper() - buildCast[UTF8String](_, s => if (s.toLong(result)) result.value - else throw new AnalysisException(s"invalid input syntax for type numeric: $s")) - case _ => - super.castToLong(from) - } - - override def castToInt(from: DataType): Any => Any = from match { - case StringType => - val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value - else throw new AnalysisException(s"invalid input syntax for type numeric: $s")) - case _ => - super.castToInt(from) - } - - override def castToShort(from: DataType): Any => Any = from match { - case StringType => - val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toShort(result)) { - result.value.toShort - } else { - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - }) - case _ => - super.castToShort(from) - } - - override def castToByte(from: DataType): Any => Any = from match { - case StringType => - val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toByte(result)) { - result.value.toByte - } else { - throw new AnalysisException(s"invalid input syntax for type numeric: $s") - }) - case _ => - super.castToByte(from) - } - - override def castToDecimalCode( - from: DataType, - target: DecimalType, - ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) - val 
canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) - from match { - case StringType => - (c, evPrim, evNull) => - code""" - try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} - } catch (java.lang.NumberFormatException e) { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - """ - case _ => super.castToDecimalCode(from, target, ctx) - } - } - - override def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => - val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) - (c, evPrim, evNull) => - code""" - UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toByte($wrapper)) { - $evPrim = (byte) $wrapper.value; - } else { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - $wrapper = null; - """ - case _ => - super.castToByteCode(from, ctx) - } - - override def castToShortCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => - val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) - (c, evPrim, evNull) => - code""" - UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toShort($wrapper)) { - $evPrim = (short) $wrapper.value; - } else { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - $wrapper = null; - """ - case _ => - super.castToShortCode(from, ctx) - } - - override def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => - val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) - (c, evPrim, evNull) => - code""" - UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toInt($wrapper)) { - $evPrim = $wrapper.value; - } else { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - $wrapper = null; - """ - case _ => - super.castToIntCode(from, ctx) - } - - override def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType => - val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) - (c, evPrim, evNull) => - code""" - UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); - if ($c.toLong($wrapper)) { - $evPrim = $wrapper.value; - } else { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } - $wrapper = null; - """ - case _ => - super.castToLongCode(from, ctx) - } - - override def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { - from match { - case StringType => - val floatStr = ctx.freshVariable("floatStr", StringType) - (c, evPrim, evNull) => - code""" - final String $floatStr = $c.toString(); - try { - $evPrim = Float.valueOf($floatStr); - } catch (java.lang.NumberFormatException e) { - final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); - if (f == null) { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } else { - $evPrim = f.floatValue(); - } - } - """ - case _ => - super.castToFloatCode(from, ctx) - } - } - - override def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { - from match { - case StringType => - val doubleStr = ctx.freshVariable("doubleStr", StringType) - (c, evPrim, evNull) => - code""" - final String $doubleStr = $c.toString(); - try { - $evPrim = Double.valueOf($doubleStr); - } catch 
(java.lang.NumberFormatException e) { - final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); - if (d == null) { - throw new AnalysisException("invalid input syntax for type numeric: $c") - } else { - $evPrim = d.doubleValue(); - } - } - """ - case _ => - super.castToDoubleCode(from, ctx) - } - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 2d8f22c34ade7..4371e8ca4c4bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -284,7 +284,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val gmtId = Option("GMT") checkEvaluation(cast("abdef", StringType), "abdef") - checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("abdef", TimestampType, gmtId), null) checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) @@ -324,7 +323,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23)) checkEvaluation(cast("23", ByteType), 23.toByte) checkEvaluation(cast("23", ShortType), 23.toShort) - checkEvaluation(cast("2012-12-11", DoubleType), null) checkEvaluation(cast(123, IntegerType), 123) checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null) @@ -410,15 +408,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) - { - val ret = cast(array, ArrayType(IntegerType, containsNull = true)) - assert(ret.resolved) - checkEvaluation(ret, Seq(123, null, null, null)) - } - { - val ret = cast(array, ArrayType(IntegerType, containsNull = false)) - assert(ret.resolved === false) - } { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved) @@ -429,15 +418,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === false) } - { - val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true)) - assert(ret.resolved) - checkEvaluation(ret, Seq(123, null, null)) - } - { - val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false)) - assert(ret.resolved === false) - } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved) @@ -464,15 +444,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) - { - val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(ret.resolved) - checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) - } - { - val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(ret.resolved === false) - } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved) @@ -486,16 +457,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) assert(ret.resolved === false) } - - { - val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(ret.resolved) - checkEvaluation(ret, Map("a" -> 123, 
"b" -> null, "c" -> null)) - } - { - val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(ret.resolved === false) - } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved) @@ -546,23 +507,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { StructField("b", StringType, nullable = false), StructField("c", StringType, nullable = false)))) - { - val ret = cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true), - StructField("d", IntegerType, nullable = true)))) - assert(ret.resolved) - checkEvaluation(ret, InternalRow(123, null, null, null)) - } - { - val ret = cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false), - StructField("d", IntegerType, nullable = true)))) - assert(ret.resolved === false) - } { val ret = cast(struct, StructType(Seq( StructField("a", BooleanType, nullable = true), @@ -581,21 +525,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === false) } - { - val ret = cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true)))) - assert(ret.resolved) - checkEvaluation(ret, InternalRow(123, null, null)) - } - { - val ret = cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false)))) - assert(ret.resolved === false) - } { val ret = cast(struct_notNull, StructType(Seq( StructField("a", BooleanType, nullable = true), @@ -921,11 +850,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { Seq("nan", "nAn", " nan ").foreach { value => checkEvaluation(cast(value, DoubleType), Double.NaN) } - - // Invalid literals when casted to double and float results in null. 
- Seq(DoubleType, FloatType).foreach { dataType => - checkEvaluation(cast("badvalue", dataType), null) - } } private def testIntMaxAndMin(dt: DataType): Unit = { @@ -1054,7 +978,6 @@ class CastSuite extends CastSuiteBase { } } - test("cast from int") { checkCast(0, false) checkCast(1, true) @@ -1214,6 +1137,125 @@ class CastSuite extends CastSuiteBase { val set = CollectSet(Literal(1)) assert(Cast.canCast(set.dataType, ArrayType(StringType, false))) } + + test("Cast should output null when ANSI is not enabled.") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) + checkEvaluation(cast("2012-12-11", DoubleType), null) + + // cast to array + val array = Literal.create(Seq("123", "true", "f", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "true", "f"), + ArrayType(StringType, containsNull = false)) + + { + val ret = cast(array, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved) + checkEvaluation(ret, Seq(123, null, null, null)) + } + { + val ret = cast(array, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved) + checkEvaluation(ret, Seq(123, null, null)) + } + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + + // cast from map + val map = Literal.create( + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "true", "c" -> "f"), + MapType(StringType, StringType, valueContainsNull = false)) + + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null)) + } + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + + // cast from struct + val struct = Literal.create( + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("true"), + UTF8String.fromString("f"), + null), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true), + StructField("d", StringType, nullable = true)))) + val struct_notNull = Literal.create( + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("true"), + UTF8String.fromString("f")), + StructType(Seq( + StructField("a", StringType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", StringType, nullable = false)))) + + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved) + checkEvaluation(ret, InternalRow(123, null, null, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = 
true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true)))) + assert(ret.resolved) + checkEvaluation(ret, InternalRow(123, null, null)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false)))) + assert(ret.resolved === false) + } + + // Invalid literals when casted to double and float results in null. + Seq(DoubleType, FloatType).foreach { dataType => + checkEvaluation(cast("badvalue", dataType), null) + } + } + } } /** @@ -1229,4 +1271,29 @@ class AnsiCastSuite extends CastSuiteBase { case _ => AnsiCast(Literal(v), targetType, timeZoneId) } } + + test("cast from invalid string to numeric should throw IllegalArgumentException") { + // cast to IntegerType + Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType => + val array = Literal.create(Seq("123", "true", "f", null), + ArrayType(StringType, containsNull = true)) + checkExceptionInExpression[IllegalArgumentException]( + cast(array, ArrayType(dataType, containsNull = true)), "invalid input") + checkExceptionInExpression[IllegalArgumentException]( + cast("string", dataType), "invalid input") + checkExceptionInExpression[IllegalArgumentException]( + cast("123-string", dataType), "invalid input") + checkExceptionInExpression[IllegalArgumentException]( + cast("2020-07-19", dataType), "invalid input") + } + + Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType => + checkExceptionInExpression[IllegalArgumentException]( + cast("string", dataType), "invalid input") + checkExceptionInExpression[IllegalArgumentException]( + cast("123.000.00", dataType), "invalid input") + checkExceptionInExpression[IllegalArgumentException]( + cast("abc.com", dataType), "invalid input") + } + } } From 74809d09bc61edad1982bf31231d37290dde278b Mon Sep 17 00:00:00 2001 From: root1 Date: Fri, 27 Dec 2019 19:22:18 +0530 Subject: [PATCH 04/15] Fix --- .../catalyst/analysis/ANSISQLStandard.scala | 40 ------------------- 1 file changed, 40 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala deleted file mode 100644 index 9b6cb1c1004bc..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ANSISQLStandard.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SQLConf - -object ANSISQLStandard { - val ANSIStandardCastRules: Seq[Rule[LogicalPlan]] = Seq(ANSICast) - - object ANSICast extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - if (SQLConf.get.ansiEnabled) { - plan.transformExpressions { - case Cast(child, dataType, timeZoneId) => - AnsiCast(child, dataType, timeZoneId) - } - } else { - plan - } - } - } -} From a336084d0fb5152d4b96b6f48114dd4c24603a74 Mon Sep 17 00:00:00 2001 From: root1 Date: Fri, 27 Dec 2019 20:30:50 +0530 Subject: [PATCH 05/15] Fix --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7d771a235f3cb..08fb8091468a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,7 +23,6 @@ import java.util.Locale import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -1178,7 +1177,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code""" try { - Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim())); ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { if ($ansiEnabled) { @@ -1424,7 +1423,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } $wrapper = null; - """ + """ case BooleanType => (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => @@ -1612,7 +1611,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } catch (java.lang.NumberFormatException e) { final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); if (d == null) { - throw new AnalysisException("invalid input syntax for type numeric: $c"); + throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); } else { $evPrim = d.doubleValue(); } From c0f8baf77bd91eda49a0a170df6ef5e12ea3efda Mon Sep 17 00:00:00 2001 From: root1 Date: Mon, 30 Dec 2019 16:15:37 +0530 Subject: [PATCH 06/15] Test Cases Fix. 
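This patch comments out the PostgreSQL-compatibility queries whose results change once string-to-numeric casts raise errors under spark.sql.ansi.enabled, and regenerates (and renumbers) the affected .sql.out golden files. The standalone Scala sketch below is not Spark code; it only illustrates the behaviour split the flag introduces. The object and method names are made up for illustration, and the exception type follows what the later patches in this series settle on (NumberFormatException).

    // Minimal sketch, assuming only the semantics described in this series:
    // non-ANSI casts turn malformed numeric strings into NULL, ANSI casts throw.
    object AnsiCastSketch {
      def castStringToInt(s: String, ansiEnabled: Boolean): Option[Int] = {
        try {
          Some(s.trim.toInt)
        } catch {
          case _: NumberFormatException if ansiEnabled =>
            throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
          case _: NumberFormatException => None // non-ANSI: behaves like SQL NULL
        }
      }

      def main(args: Array[String]): Unit = {
        println(castStringToInt("123", ansiEnabled = false))   // Some(123)
        println(castStringToInt("N A N", ansiEnabled = false)) // None
        // castStringToInt("N A N", ansiEnabled = true) would throw NumberFormatException
      }
    }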
--- .../sql-tests/inputs/postgreSQL/float4.sql | 8 +- .../sql-tests/inputs/postgreSQL/float8.sql | 8 +- .../sql-tests/inputs/postgreSQL/text.sql | 4 +- .../inputs/postgreSQL/window_part2.sql | 8 +- .../inputs/postgreSQL/window_part4.sql | 6 +- .../results/postgreSQL/float4.sql.out | 200 +++---- .../results/postgreSQL/float8.sql.out | 488 ++++++++---------- .../sql-tests/results/postgreSQL/text.sql.out | 214 ++++---- .../results/postgreSQL/window_part2.sql.out | 27 +- .../results/postgreSQL/window_part4.sql.out | 16 +- 10 files changed, 434 insertions(+), 545 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql index 2989569e219ff..6ddb74f23fb7c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql @@ -46,14 +46,14 @@ SELECT float('infinity'); SELECT float(' -INFINiTY '); -- [SPARK-27923] Spark SQL insert there bad special inputs to NULL -- bad special inputs -SELECT float('N A N'); -SELECT float('NaN x'); -SELECT float(' INFINITY x'); +-- SELECT float('N A N'); +-- SELECT float('NaN x'); +-- SELECT float(' INFINITY x'); SELECT float('Infinity') + 100.0; SELECT float('Infinity') / float('Infinity'); SELECT float('nan') / float('nan'); -SELECT float(decimal('nan')); +-- SELECT float(decimal('nan')); SELECT '' AS five, * FROM FLOAT4_TBL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql index 932cdb95fcf3a..fabdcb0dce483 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql @@ -45,15 +45,15 @@ SELECT double('infinity'); SELECT double(' -INFINiTY '); -- [SPARK-27923] Spark SQL insert there bad special inputs to NULL -- bad special inputs -SELECT double('N A N'); -SELECT double('NaN x'); -SELECT double(' INFINITY x'); +-- SELECT double('N A N'); +-- SELECT double('NaN x'); +-- SELECT double(' INFINITY x'); SELECT double('Infinity') + 100.0; SELECT double('Infinity') / double('Infinity'); SELECT double('NaN') / double('NaN'); -- [SPARK-28315] Decimal can not accept NaN as input -SELECT double(decimal('nan')); +-- SELECT double(decimal('nan')); SELECT '' AS five, * FROM FLOAT8_TBL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql index 05953123da86f..6e56485d1a8ef 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql @@ -25,8 +25,8 @@ select length(42); -- casting to text in concatenations, so long as the other input is text or -- an unknown literal. 
So these work: -- [SPARK-28033] String concatenation low priority than other arithmeticBinary -select string('four: ') || 2+2; -select 'four: ' || 2+2; +-- select string('four: ') || 2+2; +-- select 'four: ' || 2+2; -- but not this: -- Spark SQL implicit cast both side to string diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql index 395149e48d5c8..62caf2378a50b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql @@ -281,10 +281,10 @@ from numerics window w as (order by f_numeric range between 1 preceding and 1.1 following); -- currently unsupported -select id, f_numeric, first(id) over w, last(id) over w -from numerics -window w as (order by f_numeric range between - 1.1 preceding and 'NaN' following); -- error, NaN disallowed +-- select id, f_numeric, first(id) over w, last(id) over w +-- from numerics +-- window w as (order by f_numeric range between +-- 1.1 preceding and 'NaN' following); -- error, NaN disallowed drop table empsalary; drop table numerics; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql index 64ba8e3b7a5ad..653231f3cc87c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql @@ -368,9 +368,9 @@ SELECT i,SUM(v) OVER (ORDER BY i ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) -- [SPARK-29638] Spark handles 'NaN' as 0 in sums -- ensure aggregate over numeric properly recovers from NaN values -SELECT a, b, - SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) -FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b); +-- SELECT a, b, +-- SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) +-- FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b); -- It might be tempting for someone to add an inverse trans function for -- float and double precision. 
This should not be done as it can give incorrect diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index 64608a349b610..82795de1a0782 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 39 -- !query 0 @@ -91,66 +91,34 @@ struct -- !query 11 -SELECT float('N A N') --- !query 11 schema -struct --- !query 11 output -NULL - - --- !query 12 -SELECT float('NaN x') --- !query 12 schema -struct --- !query 12 output -NULL - - --- !query 13 -SELECT float(' INFINITY x') --- !query 13 schema -struct --- !query 13 output -NULL - - --- !query 14 SELECT float('Infinity') + 100.0 --- !query 14 schema +-- !query 11 schema struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> --- !query 14 output +-- !query 11 output Infinity --- !query 15 +-- !query 12 SELECT float('Infinity') / float('Infinity') --- !query 15 schema +-- !query 12 schema struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) / CAST(CAST(Infinity AS FLOAT) AS DOUBLE)):double> --- !query 15 output +-- !query 12 output NaN --- !query 16 +-- !query 13 SELECT float('nan') / float('nan') --- !query 16 schema +-- !query 13 schema struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double> --- !query 16 output +-- !query 13 output NaN --- !query 17 -SELECT float(decimal('nan')) --- !query 17 schema -struct --- !query 17 output -NULL - - --- !query 18 +-- !query 14 SELECT '' AS five, * FROM FLOAT4_TBL --- !query 18 schema +-- !query 14 schema struct --- !query 18 output +-- !query 14 output -34.84 0.0 1.2345679E-20 @@ -158,116 +126,116 @@ struct 1004.3 --- !query 19 +-- !query 15 SELECT '' AS four, f.* FROM FLOAT4_TBL f WHERE f.f1 <> '1004.3' --- !query 19 schema +-- !query 15 schema struct --- !query 19 output +-- !query 15 output -34.84 0.0 1.2345679E-20 1.2345679E20 --- !query 20 +-- !query 16 SELECT '' AS one, f.* FROM FLOAT4_TBL f WHERE f.f1 = '1004.3' --- !query 20 schema +-- !query 16 schema struct --- !query 20 output +-- !query 16 output 1004.3 --- !query 21 +-- !query 17 SELECT '' AS three, f.* FROM FLOAT4_TBL f WHERE '1004.3' > f.f1 --- !query 21 schema +-- !query 17 schema struct --- !query 21 output +-- !query 17 output -34.84 0.0 1.2345679E-20 --- !query 22 +-- !query 18 SELECT '' AS three, f.* FROM FLOAT4_TBL f WHERE f.f1 < '1004.3' --- !query 22 schema +-- !query 18 schema struct --- !query 22 output +-- !query 18 output -34.84 0.0 1.2345679E-20 --- !query 23 +-- !query 19 SELECT '' AS four, f.* FROM FLOAT4_TBL f WHERE '1004.3' >= f.f1 --- !query 23 schema +-- !query 19 schema struct --- !query 23 output +-- !query 19 output -34.84 0.0 1.2345679E-20 1004.3 --- !query 24 +-- !query 20 SELECT '' AS four, f.* FROM FLOAT4_TBL f WHERE f.f1 <= '1004.3' --- !query 24 schema +-- !query 20 schema struct --- !query 24 output +-- !query 20 output -34.84 0.0 1.2345679E-20 1004.3 --- !query 25 +-- !query 21 SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 25 schema +-- !query 21 schema struct --- !query 25 output +-- !query 21 output 1.2345679E-20 -1.2345678720289608E-19 1.2345679E20 -1.2345678955701443E21 1004.3 -10042.999877929688 --- !query 26 +-- !query 22 SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT4_TBL f WHERE 
f.f1 > '0.0' --- !query 26 schema +-- !query 22 schema struct --- !query 26 output +-- !query 22 output 1.2345679E-20 -10.0 1.2345679E20 1.2345678955701443E20 1004.3 994.2999877929688 --- !query 27 +-- !query 23 SELECT '' AS three, f.f1, f.f1 / '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 27 schema +-- !query 23 schema struct --- !query 27 output +-- !query 23 output 1.2345679E-20 -1.2345678720289608E-21 1.2345679E20 -1.2345678955701443E19 1004.3 -100.42999877929688 --- !query 28 +-- !query 24 SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 28 schema +-- !query 24 schema struct --- !query 28 output +-- !query 24 output 1.2345679E-20 10.0 1.2345679E20 1.2345678955701443E20 1004.3 1014.2999877929688 --- !query 29 +-- !query 25 SELECT '' AS five, * FROM FLOAT4_TBL --- !query 29 schema +-- !query 25 schema struct --- !query 29 output +-- !query 25 output -34.84 0.0 1.2345679E-20 @@ -275,107 +243,107 @@ struct 1004.3 --- !query 30 +-- !query 26 SELECT smallint(float('32767.4')) --- !query 30 schema +-- !query 26 schema struct --- !query 30 output +-- !query 26 output 32767 --- !query 31 +-- !query 27 SELECT smallint(float('32767.6')) --- !query 31 schema +-- !query 27 schema struct --- !query 31 output +-- !query 27 output 32767 --- !query 32 +-- !query 28 SELECT smallint(float('-32768.4')) --- !query 32 schema +-- !query 28 schema struct --- !query 32 output +-- !query 28 output -32768 --- !query 33 +-- !query 29 SELECT smallint(float('-32768.6')) --- !query 33 schema +-- !query 29 schema struct --- !query 33 output +-- !query 29 output -32768 --- !query 34 +-- !query 30 SELECT int(float('2147483520')) --- !query 34 schema +-- !query 30 schema struct --- !query 34 output +-- !query 30 output 2147483520 --- !query 35 +-- !query 31 SELECT int(float('2147483647')) --- !query 35 schema +-- !query 31 schema struct --- !query 35 output +-- !query 31 output 2147483647 --- !query 36 +-- !query 32 SELECT int(float('-2147483648.5')) --- !query 36 schema +-- !query 32 schema struct --- !query 36 output +-- !query 32 output -2147483648 --- !query 37 +-- !query 33 SELECT int(float('-2147483900')) --- !query 37 schema +-- !query 33 schema struct<> --- !query 37 output +-- !query 33 output java.lang.ArithmeticException Casting -2.1474839E9 to int causes overflow --- !query 38 +-- !query 34 SELECT bigint(float('9223369837831520256')) --- !query 38 schema +-- !query 34 schema struct --- !query 38 output +-- !query 34 output 9223369837831520256 --- !query 39 +-- !query 35 SELECT bigint(float('9223372036854775807')) --- !query 39 schema +-- !query 35 schema struct --- !query 39 output +-- !query 35 output 9223372036854775807 --- !query 40 +-- !query 36 SELECT bigint(float('-9223372036854775808.5')) --- !query 40 schema +-- !query 36 schema struct --- !query 40 output +-- !query 36 output -9223372036854775808 --- !query 41 +-- !query 37 SELECT bigint(float('-9223380000000000000')) --- !query 41 schema +-- !query 37 schema struct<> --- !query 41 output +-- !query 37 output java.lang.ArithmeticException Casting -9.22338E18 to int causes overflow --- !query 42 +-- !query 38 DROP TABLE FLOAT4_TBL --- !query 42 schema +-- !query 38 schema struct<> --- !query 42 output +-- !query 38 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out index d38e36e956985..d6742f4c37b36 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 95 +-- Number of queries: 91 -- !query 0 @@ -123,66 +123,34 @@ struct -- !query 15 -SELECT double('N A N') --- !query 15 schema -struct --- !query 15 output -NULL - - --- !query 16 -SELECT double('NaN x') --- !query 16 schema -struct --- !query 16 output -NULL - - --- !query 17 -SELECT double(' INFINITY x') --- !query 17 schema -struct --- !query 17 output -NULL - - --- !query 18 SELECT double('Infinity') + 100.0 --- !query 18 schema +-- !query 15 schema struct<(CAST(Infinity AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> --- !query 18 output +-- !query 15 output Infinity --- !query 19 +-- !query 16 SELECT double('Infinity') / double('Infinity') --- !query 19 schema +-- !query 16 schema struct<(CAST(Infinity AS DOUBLE) / CAST(Infinity AS DOUBLE)):double> --- !query 19 output +-- !query 16 output NaN --- !query 20 +-- !query 17 SELECT double('NaN') / double('NaN') --- !query 20 schema +-- !query 17 schema struct<(CAST(NaN AS DOUBLE) / CAST(NaN AS DOUBLE)):double> --- !query 20 output +-- !query 17 output NaN --- !query 21 -SELECT double(decimal('nan')) --- !query 21 schema -struct --- !query 21 output -NULL - - --- !query 22 +-- !query 18 SELECT '' AS five, * FROM FLOAT8_TBL --- !query 22 schema +-- !query 18 schema struct --- !query 22 output +-- !query 18 output -34.84 0.0 1.2345678901234E-200 @@ -190,121 +158,121 @@ struct 1004.3 --- !query 23 +-- !query 19 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <> '1004.3' --- !query 23 schema +-- !query 19 schema struct --- !query 23 output +-- !query 19 output -34.84 0.0 1.2345678901234E-200 1.2345678901234E200 --- !query 24 +-- !query 20 SELECT '' AS one, f.* FROM FLOAT8_TBL f WHERE f.f1 = '1004.3' --- !query 24 schema +-- !query 20 schema struct --- !query 24 output +-- !query 20 output 1004.3 --- !query 25 +-- !query 21 SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE '1004.3' > f.f1 --- !query 25 schema +-- !query 21 schema struct --- !query 25 output +-- !query 21 output -34.84 0.0 1.2345678901234E-200 --- !query 26 +-- !query 22 SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE f.f1 < '1004.3' --- !query 26 schema +-- !query 22 schema struct --- !query 26 output +-- !query 22 output -34.84 0.0 1.2345678901234E-200 --- !query 27 +-- !query 23 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE '1004.3' >= f.f1 --- !query 27 schema +-- !query 23 schema struct --- !query 27 output +-- !query 23 output -34.84 0.0 1.2345678901234E-200 1004.3 --- !query 28 +-- !query 24 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <= '1004.3' --- !query 28 schema +-- !query 24 schema struct --- !query 28 output +-- !query 24 output -34.84 0.0 1.2345678901234E-200 1004.3 --- !query 29 +-- !query 25 SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 29 schema +-- !query 25 schema struct --- !query 29 output +-- !query 25 output 1.2345678901234E-200 -1.2345678901234E-199 1.2345678901234E200 -1.2345678901234E201 1004.3 -10043.0 --- !query 30 +-- !query 26 SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 30 schema +-- !query 26 schema struct --- !query 30 output +-- !query 26 output 1.2345678901234E-200 -10.0 1.2345678901234E200 1.2345678901234E200 1004.3 994.3 --- !query 31 +-- !query 27 SELECT '' AS three, f.f1, f.f1 / '-10' AS x FROM FLOAT8_TBL 
f WHERE f.f1 > '0.0' --- !query 31 schema +-- !query 27 schema struct --- !query 31 output +-- !query 27 output 1.2345678901234E-200 -1.2345678901234E-201 1.2345678901234E200 -1.2345678901234E199 1004.3 -100.42999999999999 --- !query 32 +-- !query 28 SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 32 schema +-- !query 28 schema struct --- !query 32 output +-- !query 28 output 1.2345678901234E-200 10.0 1.2345678901234E200 1.2345678901234E200 1004.3 1014.3 --- !query 33 +-- !query 29 SELECT '' AS five, f.f1, round(f.f1) AS round_f1 FROM FLOAT8_TBL f --- !query 33 schema +-- !query 29 schema struct --- !query 33 output +-- !query 29 output -34.84 -35.0 0.0 0.0 1.2345678901234E-200 0.0 @@ -312,11 +280,11 @@ struct 1004.3 1004.0 --- !query 34 +-- !query 30 select ceil(f1) as ceil_f1 from float8_tbl f --- !query 34 schema +-- !query 30 schema struct --- !query 34 output +-- !query 30 output -34 0 1 @@ -324,11 +292,11 @@ struct 9223372036854775807 --- !query 35 +-- !query 31 select ceiling(f1) as ceiling_f1 from float8_tbl f --- !query 35 schema +-- !query 31 schema struct --- !query 35 output +-- !query 31 output -34 0 1 @@ -336,11 +304,11 @@ struct 9223372036854775807 --- !query 36 +-- !query 32 select floor(f1) as floor_f1 from float8_tbl f --- !query 36 schema +-- !query 32 schema struct --- !query 36 output +-- !query 32 output -35 0 0 @@ -348,11 +316,11 @@ struct 9223372036854775807 --- !query 37 +-- !query 33 select sign(f1) as sign_f1 from float8_tbl f --- !query 37 schema +-- !query 33 schema struct --- !query 37 output +-- !query 33 output -1.0 0.0 1.0 @@ -360,87 +328,87 @@ struct 1.0 --- !query 38 +-- !query 34 SELECT sqrt(double('64')) AS eight --- !query 38 schema +-- !query 34 schema struct --- !query 38 output +-- !query 34 output 8.0 --- !query 39 +-- !query 35 SELECT power(double('144'), double('0.5')) --- !query 39 schema +-- !query 35 schema struct --- !query 39 output +-- !query 35 output 12.0 --- !query 40 +-- !query 36 SELECT power(double('NaN'), double('0.5')) --- !query 40 schema +-- !query 36 schema struct --- !query 40 output +-- !query 36 output NaN --- !query 41 +-- !query 37 SELECT power(double('144'), double('NaN')) --- !query 41 schema +-- !query 37 schema struct --- !query 41 output +-- !query 37 output NaN --- !query 42 +-- !query 38 SELECT power(double('NaN'), double('NaN')) --- !query 42 schema +-- !query 38 schema struct --- !query 42 output +-- !query 38 output NaN --- !query 43 +-- !query 39 SELECT power(double('-1'), double('NaN')) --- !query 43 schema +-- !query 39 schema struct --- !query 43 output +-- !query 39 output NaN --- !query 44 +-- !query 40 SELECT power(double('1'), double('NaN')) --- !query 44 schema +-- !query 40 schema struct --- !query 44 output +-- !query 40 output NaN --- !query 45 +-- !query 41 SELECT power(double('NaN'), double('0')) --- !query 45 schema +-- !query 41 schema struct --- !query 45 output +-- !query 41 output 1.0 --- !query 46 +-- !query 42 SELECT '' AS three, f.f1, exp(ln(f.f1)) AS exp_ln_f1 FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 46 schema +-- !query 42 schema struct --- !query 46 output +-- !query 42 output 1.2345678901234E-200 1.2345678901233948E-200 1.2345678901234E200 1.234567890123379E200 1004.3 1004.3000000000004 --- !query 47 +-- !query 43 SELECT '' AS five, * FROM FLOAT8_TBL --- !query 47 schema +-- !query 43 schema struct --- !query 47 output +-- !query 43 output -34.84 0.0 1.2345678901234E-200 @@ -448,22 +416,22 @@ struct 1004.3 --- !query 48 +-- !query 44 
CREATE TEMPORARY VIEW UPDATED_FLOAT8_TBL as SELECT CASE WHEN FLOAT8_TBL.f1 > '0.0' THEN FLOAT8_TBL.f1 * '-1' ELSE FLOAT8_TBL.f1 END AS f1 FROM FLOAT8_TBL --- !query 48 schema +-- !query 44 schema struct<> --- !query 48 output +-- !query 44 output --- !query 49 +-- !query 45 SELECT '' AS bad, f.f1 * '1e200' from UPDATED_FLOAT8_TBL f --- !query 49 schema +-- !query 45 schema struct --- !query 49 output +-- !query 45 output -1.0042999999999999E203 -1.2345678901234 -3.484E201 @@ -471,11 +439,11 @@ struct 0.0 --- !query 50 +-- !query 46 SELECT '' AS five, * FROM UPDATED_FLOAT8_TBL --- !query 50 schema +-- !query 46 schema struct --- !query 50 output +-- !query 46 output -1.2345678901234E-200 -1.2345678901234E200 -1004.3 @@ -483,251 +451,251 @@ struct 0.0 --- !query 51 +-- !query 47 SELECT sinh(double('1')) --- !query 51 schema +-- !query 47 schema struct --- !query 51 output +-- !query 47 output 1.1752011936438014 --- !query 52 +-- !query 48 SELECT cosh(double('1')) --- !query 52 schema +-- !query 48 schema struct --- !query 52 output +-- !query 48 output 1.543080634815244 --- !query 53 +-- !query 49 SELECT tanh(double('1')) --- !query 53 schema +-- !query 49 schema struct --- !query 53 output +-- !query 49 output 0.7615941559557649 --- !query 54 +-- !query 50 SELECT asinh(double('1')) --- !query 54 schema +-- !query 50 schema struct --- !query 54 output +-- !query 50 output 0.8813735870195429 --- !query 55 +-- !query 51 SELECT acosh(double('2')) --- !query 55 schema +-- !query 51 schema struct --- !query 55 output +-- !query 51 output 1.3169578969248166 --- !query 56 +-- !query 52 SELECT atanh(double('0.5')) --- !query 56 schema +-- !query 52 schema struct --- !query 56 output +-- !query 52 output 0.5493061443340548 --- !query 57 +-- !query 53 SELECT sinh(double('Infinity')) --- !query 57 schema +-- !query 53 schema struct --- !query 57 output +-- !query 53 output Infinity --- !query 58 +-- !query 54 SELECT sinh(double('-Infinity')) --- !query 58 schema +-- !query 54 schema struct --- !query 58 output +-- !query 54 output -Infinity --- !query 59 +-- !query 55 SELECT sinh(double('NaN')) --- !query 59 schema +-- !query 55 schema struct --- !query 59 output +-- !query 55 output NaN --- !query 60 +-- !query 56 SELECT cosh(double('Infinity')) --- !query 60 schema +-- !query 56 schema struct --- !query 60 output +-- !query 56 output Infinity --- !query 61 +-- !query 57 SELECT cosh(double('-Infinity')) --- !query 61 schema +-- !query 57 schema struct --- !query 61 output +-- !query 57 output Infinity --- !query 62 +-- !query 58 SELECT cosh(double('NaN')) --- !query 62 schema +-- !query 58 schema struct --- !query 62 output +-- !query 58 output NaN --- !query 63 +-- !query 59 SELECT tanh(double('Infinity')) --- !query 63 schema +-- !query 59 schema struct --- !query 63 output +-- !query 59 output 1.0 --- !query 64 +-- !query 60 SELECT tanh(double('-Infinity')) --- !query 64 schema +-- !query 60 schema struct --- !query 64 output +-- !query 60 output -1.0 --- !query 65 +-- !query 61 SELECT tanh(double('NaN')) --- !query 65 schema +-- !query 61 schema struct --- !query 65 output +-- !query 61 output NaN --- !query 66 +-- !query 62 SELECT asinh(double('Infinity')) --- !query 66 schema +-- !query 62 schema struct --- !query 66 output +-- !query 62 output Infinity --- !query 67 +-- !query 63 SELECT asinh(double('-Infinity')) --- !query 67 schema +-- !query 63 schema struct --- !query 67 output +-- !query 63 output -Infinity --- !query 68 +-- !query 64 SELECT asinh(double('NaN')) --- !query 68 schema +-- 
!query 64 schema struct --- !query 68 output +-- !query 64 output NaN --- !query 69 +-- !query 65 SELECT acosh(double('Infinity')) --- !query 69 schema +-- !query 65 schema struct --- !query 69 output +-- !query 65 output Infinity --- !query 70 +-- !query 66 SELECT acosh(double('-Infinity')) --- !query 70 schema +-- !query 66 schema struct --- !query 70 output +-- !query 66 output NaN --- !query 71 +-- !query 67 SELECT acosh(double('NaN')) --- !query 71 schema +-- !query 67 schema struct --- !query 71 output +-- !query 67 output NaN --- !query 72 +-- !query 68 SELECT atanh(double('Infinity')) --- !query 72 schema +-- !query 68 schema struct --- !query 72 output +-- !query 68 output NaN --- !query 73 +-- !query 69 SELECT atanh(double('-Infinity')) --- !query 73 schema +-- !query 69 schema struct --- !query 73 output +-- !query 69 output NaN --- !query 74 +-- !query 70 SELECT atanh(double('NaN')) --- !query 74 schema +-- !query 70 schema struct --- !query 74 output +-- !query 70 output NaN --- !query 75 +-- !query 71 TRUNCATE TABLE FLOAT8_TBL --- !query 75 schema +-- !query 71 schema struct<> --- !query 75 output +-- !query 71 output --- !query 76 +-- !query 72 INSERT INTO FLOAT8_TBL VALUES (double('0.0')) --- !query 76 schema +-- !query 72 schema struct<> --- !query 76 output +-- !query 72 output --- !query 77 +-- !query 73 INSERT INTO FLOAT8_TBL VALUES (double('-34.84')) --- !query 77 schema +-- !query 73 schema struct<> --- !query 77 output +-- !query 73 output --- !query 78 +-- !query 74 INSERT INTO FLOAT8_TBL VALUES (double('-1004.30')) --- !query 78 schema +-- !query 74 schema struct<> --- !query 78 output +-- !query 74 output --- !query 79 +-- !query 75 INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e+200')) --- !query 79 schema +-- !query 75 schema struct<> --- !query 79 output +-- !query 75 output --- !query 80 +-- !query 76 INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e-200')) --- !query 80 schema +-- !query 76 schema struct<> --- !query 80 output +-- !query 76 output --- !query 81 +-- !query 77 SELECT '' AS five, * FROM FLOAT8_TBL --- !query 81 schema +-- !query 77 schema struct --- !query 81 output +-- !query 77 output -1.2345678901234E-200 -1.2345678901234E200 -1004.3 @@ -735,106 +703,106 @@ struct 0.0 --- !query 82 +-- !query 78 SELECT smallint(double('32767.4')) --- !query 82 schema +-- !query 78 schema struct --- !query 82 output +-- !query 78 output 32767 --- !query 83 +-- !query 79 SELECT smallint(double('32767.6')) --- !query 83 schema +-- !query 79 schema struct --- !query 83 output +-- !query 79 output 32767 --- !query 84 +-- !query 80 SELECT smallint(double('-32768.4')) --- !query 84 schema +-- !query 80 schema struct --- !query 84 output +-- !query 80 output -32768 --- !query 85 +-- !query 81 SELECT smallint(double('-32768.6')) --- !query 85 schema +-- !query 81 schema struct --- !query 85 output +-- !query 81 output -32768 --- !query 86 +-- !query 82 SELECT int(double('2147483647.4')) --- !query 86 schema +-- !query 82 schema struct --- !query 86 output +-- !query 82 output 2147483647 --- !query 87 +-- !query 83 SELECT int(double('2147483647.6')) --- !query 87 schema +-- !query 83 schema struct --- !query 87 output +-- !query 83 output 2147483647 --- !query 88 +-- !query 84 SELECT int(double('-2147483648.4')) --- !query 88 schema +-- !query 84 schema struct --- !query 88 output +-- !query 84 output -2147483648 --- !query 89 +-- !query 85 SELECT int(double('-2147483648.6')) --- !query 89 schema +-- !query 85 schema struct --- !query 89 output +-- 
!query 85 output -2147483648 --- !query 90 +-- !query 86 SELECT bigint(double('9223372036854773760')) --- !query 90 schema +-- !query 86 schema struct --- !query 90 output +-- !query 86 output 9223372036854773760 --- !query 91 +-- !query 87 SELECT bigint(double('9223372036854775807')) --- !query 91 schema +-- !query 87 schema struct --- !query 91 output +-- !query 87 output 9223372036854775807 --- !query 92 +-- !query 88 SELECT bigint(double('-9223372036854775808.5')) --- !query 92 schema +-- !query 88 schema struct --- !query 92 output +-- !query 88 output -9223372036854775808 --- !query 93 +-- !query 89 SELECT bigint(double('-9223372036854780000')) --- !query 93 schema +-- !query 89 schema struct<> --- !query 93 output +-- !query 89 output java.lang.ArithmeticException Casting -9.22337203685478E18 to long causes overflow --- !query 94 +-- !query 90 DROP TABLE FLOAT8_TBL --- !query 94 schema +-- !query 90 schema struct<> --- !query 94 output +-- !query 90 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index 311b0eb5a5844..3edc3b0197024 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 42 +-- Number of queries: 40 -- !query 0 @@ -60,101 +60,85 @@ struct -- !query 7 -select string('four: ') || 2+2 --- !query 7 schema -struct<(CAST(concat(CAST(four: AS STRING), CAST(2 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE)):double> --- !query 7 output -NULL - - --- !query 8 -select 'four: ' || 2+2 --- !query 8 schema -struct<(CAST(concat(four: , CAST(2 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE)):double> --- !query 8 output -NULL - - --- !query 9 select 3 || 4.0 --- !query 9 schema +-- !query 7 schema struct --- !query 9 output +-- !query 7 output 34.0 --- !query 10 +-- !query 8 /* * various string functions */ select concat('one') --- !query 10 schema +-- !query 8 schema struct --- !query 10 output +-- !query 8 output one --- !query 11 +-- !query 9 select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) --- !query 11 schema +-- !query 9 schema struct --- !query 11 output +-- !query 9 output 123hellotruefalse2010-03-09 --- !query 12 +-- !query 10 select concat_ws('#','one') --- !query 12 schema +-- !query 10 schema struct --- !query 12 output +-- !query 10 output one --- !query 13 +-- !query 11 select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) --- !query 13 schema +-- !query 11 schema struct --- !query 13 output +-- !query 11 output 1#x#x#hello#true#false#x-03-09 --- !query 14 +-- !query 12 select concat_ws(',',10,20,null,30) --- !query 14 schema +-- !query 12 schema struct --- !query 14 output +-- !query 12 output 10,20,30 --- !query 15 +-- !query 13 select concat_ws('',10,20,null,30) --- !query 15 schema +-- !query 13 schema struct --- !query 15 output +-- !query 13 output 102030 --- !query 16 +-- !query 14 select concat_ws(NULL,10,20,null,30) is null --- !query 16 schema +-- !query 14 schema struct<(concat_ws(CAST(NULL AS STRING), CAST(10 AS STRING), CAST(20 AS STRING), NULL, CAST(30 AS STRING)) IS NULL):boolean> --- !query 16 output +-- !query 14 output true --- !query 17 +-- !query 15 select reverse('abcde') --- !query 17 schema +-- !query 15 schema struct --- !query 17 output +-- !query 15 output edcba --- !query 18 +-- !query 16 select i, 
left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i --- !query 18 schema +-- !query 16 schema struct --- !query 18 output +-- !query 16 output -5 -4 -3 @@ -168,192 +152,192 @@ struct 5 ahoj ahoj --- !query 19 +-- !query 17 /* * format */ select format_string(NULL) --- !query 19 schema +-- !query 17 schema struct --- !query 19 output +-- !query 17 output NULL --- !query 20 +-- !query 18 select format_string('Hello') --- !query 20 schema +-- !query 18 schema struct --- !query 20 output +-- !query 18 output Hello --- !query 21 +-- !query 19 select format_string('Hello %s', 'World') --- !query 21 schema +-- !query 19 schema struct --- !query 21 output +-- !query 19 output Hello World --- !query 22 +-- !query 20 select format_string('Hello %%') --- !query 22 schema +-- !query 20 schema struct --- !query 22 output +-- !query 20 output Hello % --- !query 23 +-- !query 21 select format_string('Hello %%%%') --- !query 23 schema +-- !query 21 schema struct --- !query 23 output +-- !query 21 output Hello %% --- !query 24 +-- !query 22 select format_string('Hello %s %s', 'World') --- !query 24 schema +-- !query 22 schema struct<> --- !query 24 output +-- !query 22 output java.util.MissingFormatArgumentException Format specifier '%s' --- !query 25 +-- !query 23 select format_string('Hello %s') --- !query 25 schema +-- !query 23 schema struct<> --- !query 25 output +-- !query 23 output java.util.MissingFormatArgumentException Format specifier '%s' --- !query 26 +-- !query 24 select format_string('Hello %x', 20) --- !query 26 schema +-- !query 24 schema struct --- !query 26 output +-- !query 24 output Hello 14 --- !query 27 +-- !query 25 select format_string('%1$s %3$s', 1, 2, 3) --- !query 27 schema +-- !query 25 schema struct --- !query 27 output +-- !query 25 output 1 3 --- !query 28 +-- !query 26 select format_string('%1$s %12$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) --- !query 28 schema +-- !query 26 schema struct --- !query 28 output +-- !query 26 output 1 12 --- !query 29 +-- !query 27 select format_string('%1$s %4$s', 1, 2, 3) --- !query 29 schema +-- !query 27 schema struct<> --- !query 29 output +-- !query 27 output java.util.MissingFormatArgumentException Format specifier '%4$s' --- !query 30 +-- !query 28 select format_string('%1$s %13$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) --- !query 30 schema +-- !query 28 schema struct<> --- !query 30 output +-- !query 28 output java.util.MissingFormatArgumentException Format specifier '%13$s' --- !query 31 +-- !query 29 select format_string('%0$s', 'Hello') --- !query 31 schema +-- !query 29 schema struct --- !query 31 output +-- !query 29 output Hello --- !query 32 +-- !query 30 select format_string('Hello %s %1$s %s', 'World', 'Hello again') --- !query 32 schema +-- !query 30 schema struct --- !query 32 output +-- !query 30 output Hello World World Hello again --- !query 33 +-- !query 31 select format_string('Hello %s %s, %2$s %2$s', 'World', 'Hello again') --- !query 33 schema +-- !query 31 schema struct --- !query 33 output +-- !query 31 output Hello World Hello again, Hello again Hello again --- !query 34 +-- !query 32 select format_string('>>%10s<<', 'Hello') --- !query 34 schema +-- !query 32 schema struct>%10s<<, Hello):string> --- !query 34 output +-- !query 32 output >> Hello<< --- !query 35 +-- !query 33 select format_string('>>%10s<<', NULL) --- !query 35 schema +-- !query 33 schema struct>%10s<<, NULL):string> --- !query 35 output +-- !query 33 output >> null<< --- !query 36 +-- !query 34 select format_string('>>%10s<<', 
'') --- !query 36 schema +-- !query 34 schema struct>%10s<<, ):string> --- !query 36 output +-- !query 34 output >> << --- !query 37 +-- !query 35 select format_string('>>%-10s<<', '') --- !query 37 schema +-- !query 35 schema struct>%-10s<<, ):string> --- !query 37 output +-- !query 35 output >> << --- !query 38 +-- !query 36 select format_string('>>%-10s<<', 'Hello') --- !query 38 schema +-- !query 36 schema struct>%-10s<<, Hello):string> --- !query 38 output +-- !query 36 output >>Hello << --- !query 39 +-- !query 37 select format_string('>>%-10s<<', NULL) --- !query 39 schema +-- !query 37 schema struct>%-10s<<, NULL):string> --- !query 39 output +-- !query 37 output >>null << --- !query 40 +-- !query 38 select format_string('>>%1$10s<<', 'Hello') --- !query 40 schema +-- !query 38 schema struct>%1$10s<<, Hello):string> --- !query 40 output +-- !query 38 output >> Hello<< --- !query 41 +-- !query 39 DROP TABLE TEXT_TBL --- !query 41 schema +-- !query 39 schema struct<> --- !query 41 output +-- !query 39 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out index 0d2c78847b97c..bffc3f4f885a8 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 30 -- !query 0 @@ -447,33 +447,16 @@ struct +struct<> -- !query 28 output -1 -3 NULL NULL -2 -1 NULL NULL -3 0 NULL NULL -4 1 NULL NULL -5 1 NULL NULL -6 2 NULL NULL -7 100 NULL NULL + -- !query 29 -drop table empsalary +drop table numerics -- !query 29 schema struct<> -- !query 29 output - - --- !query 30 -drop table numerics --- !query 30 schema -struct<> --- !query 30 output - diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index 54ceacd3b3b3e..b8a9b32027174 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 39 +-- Number of queries: 38 -- !query 0 @@ -491,17 +491,3 @@ struct --- !query 38 output -1 1 1 -2 2 3 -3 NULL 2 -4 3 3 -5 4 7 From f46181dedc9caf1b2df65b3b3fb747ec122d0d49 Mon Sep 17 00:00:00 2001 From: root1 Date: Sat, 4 Jan 2020 12:39:38 +0530 Subject: [PATCH 07/15] Fix --- .../spark/sql/catalyst/expressions/Cast.scala | 42 ++++++++++++------- .../sql/catalyst/expressions/CastSuite.scala | 20 ++++----- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 08fb8091468a6..97f54ff7c4c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -484,8 +484,13 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToLong(from: DataType): Any => Any = from match { case StringType if ansiEnabled => val result = new LongWrapper() - buildCast[UTF8String](_, s => if (s.toLong(result)) result.value - else throw new 
IllegalArgumentException(s"invalid input syntax for type numeric: $s")) + buildCast[UTF8String](_, s => { + if (s.toLong(result)) { + result.value + } else { + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") + } + }) case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) @@ -505,8 +510,13 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToInt(from: DataType): Any => Any = from match { case StringType if ansiEnabled => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value - else throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s")) + buildCast[UTF8String](_, s => { + if (s.toInt(result)) { + result.value + } else { + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") + } + }) case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) @@ -532,7 +542,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit result.value.toShort } else { if (ansiEnabled) { - throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") } else { null } @@ -577,7 +587,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit result.value.toByte } else { if (ansiEnabled) { - throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") } else { null } @@ -654,7 +664,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } catch { case _: NumberFormatException => if (ansiEnabled) { - throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") } else { null } @@ -687,7 +697,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case _: NumberFormatException => val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) if(d == null) { - throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") } else { d.asInstanceOf[Double].doubleValue() } @@ -720,7 +730,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case _: NumberFormatException => val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) if (f == null) { - throw new IllegalArgumentException(s"invalid input syntax for type numeric: $s") + throw new NumberFormatException(s"invalid input syntax for type numeric: $s") } else { f.asInstanceOf[Float].floatValue() } @@ -1181,7 +1191,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { if ($ansiEnabled) { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evNull =true; } @@ -1417,7 +1427,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit $evPrim = (byte) $wrapper.value; } else { if ($ansiEnabled) { - throw new IllegalArgumentException("invalid input 
syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evNull = true; } @@ -1452,7 +1462,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit $evPrim = (short) $wrapper.value; } else { if ($ansiEnabled) { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evNull = true; } @@ -1486,7 +1496,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; } else { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } $wrapper = null; """ @@ -1530,7 +1540,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit $evPrim = $wrapper.value; } else { if ($ansiEnabled) { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evNull = true; } @@ -1564,7 +1574,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } catch (java.lang.NumberFormatException e) { final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); if (f == null) { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evPrim = f.floatValue(); } @@ -1611,7 +1621,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } catch (java.lang.NumberFormatException e) { final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); if (d == null) { - throw new IllegalArgumentException("invalid input syntax for type numeric: $c"); + throw new NumberFormatException("invalid input syntax for type numeric: $c"); } else { $evPrim = d.doubleValue(); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4371e8ca4c4bf..97a03a79d28a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -1138,7 +1138,7 @@ class CastSuite extends CastSuiteBase { assert(Cast.canCast(set.dataType, ArrayType(StringType, false))) } - test("Cast should output null when ANSI is not enabled.") { + test("Cast should output null for invalid strings when ANSI is not enabled.") { withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("2012-12-11", DoubleType), null) @@ -1272,27 +1272,27 @@ class AnsiCastSuite extends CastSuiteBase { } } - test("cast from invalid string to numeric should throw IllegalArgumentException") { + test("cast from invalid string to numeric should throw NumberFormatException") { // cast to IntegerType Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType => val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast(array, ArrayType(dataType, containsNull = true)), 
"invalid input") - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast("string", dataType), "invalid input") - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast("123-string", dataType), "invalid input") - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast("2020-07-19", dataType), "invalid input") } - Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType => - checkExceptionInExpression[IllegalArgumentException]( + Seq(DecimalType.USER_DEFAULT).foreach { dataType => + checkExceptionInExpression[NumberFormatException]( cast("string", dataType), "invalid input") - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast("123.000.00", dataType), "invalid input") - checkExceptionInExpression[IllegalArgumentException]( + checkExceptionInExpression[NumberFormatException]( cast("abc.com", dataType), "invalid input") } } From d3ffa3c2b4f1c7bc2b098e37f2f8ae0d4a01eff0 Mon Sep 17 00:00:00 2001 From: root1 Date: Sat, 4 Jan 2020 16:09:08 +0530 Subject: [PATCH 08/15] Fix --- .../apache/spark/unsafe/types/UTF8String.java | 13 +++++ .../spark/sql/catalyst/expressions/Cast.scala | 56 ++++--------------- .../sql/catalyst/expressions/CastSuite.scala | 2 +- 3 files changed, 26 insertions(+), 45 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3754a1a0374a8..f87e9890dcc03 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1294,6 +1294,19 @@ public boolean toByte(IntWrapper intWrapper) { return false; } + public boolean toIntExact(IntWrapper intWrapper, boolean ansiEnabled, String integralType, UTF8String c) { + if ((integralType.equals("short") && toShort(intWrapper)) || + (integralType.equals("byte") && toByte(intWrapper)) || + (integralType.equals("int") && toInt(intWrapper)) ) { + return true; + } + if (ansiEnabled) { + throw new NumberFormatException("invalid input syntax for type numeric: " + c); + } else { + return false; + } + } + @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 97f54ff7c4c12..e5a5f5b367266 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -508,18 +508,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { - case StringType if ansiEnabled => + case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => { - if (s.toInt(result)) { + if (s.toIntExact(result, ansiEnabled, "int", s)) { result.value } else { - throw new NumberFormatException(s"invalid input syntax for type numeric: $s") + null } }) - case StringType => - val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => 
@@ -538,14 +535,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toShort(result)) { + buildCast[UTF8String](_, s => if (s.toIntExact(result, ansiEnabled, "short", s)) { result.value.toShort } else { - if (ansiEnabled) { - throw new NumberFormatException(s"invalid input syntax for type numeric: $s") - } else { - null - } + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -583,14 +576,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toByte(result)) { + buildCast[UTF8String](_, s => if (s.toIntExact(result, ansiEnabled, "byte", s)) { result.value.toByte } else { - if (ansiEnabled) { - throw new NumberFormatException(s"invalid input syntax for type numeric: $s") - } else { - null - } + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -1423,14 +1412,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toByte($wrapper)) { + if ($c.toIntExact($wrapper, $ansiEnabled, "byte", $c)) { $evPrim = (byte) $wrapper.value; } else { - if ($ansiEnabled) { - throw new NumberFormatException("invalid input syntax for type numeric: $c"); - } else { - $evNull = true; - } + $evNull = true; } $wrapper = null; """ @@ -1458,14 +1443,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toShort($wrapper)) { + if ($c.toIntExact($wrapper, $ansiEnabled, "short", $c)) { $evPrim = (short) $wrapper.value; } else { - if ($ansiEnabled) { - throw new NumberFormatException("invalid input syntax for type numeric: $c"); - } else { - $evNull = true; - } + $evNull = true; } $wrapper = null; """ @@ -1488,24 +1469,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToIntCode( from: DataType, ctx: CodegenContext): CastFunction = from match { - case StringType if ansiEnabled => - val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) - (c, evPrim, evNull) => - code""" - UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toInt($wrapper)) { - $evPrim = $wrapper.value; - } else { - throw new NumberFormatException("invalid input syntax for type numeric: $c"); - } - $wrapper = null; - """ case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toInt($wrapper)) { + if ($c.toIntExact($wrapper, $ansiEnabled, "int", $c)) { $evPrim = $wrapper.value; } else { $evNull = true; @@ -1532,7 +1501,6 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) - (c, evPrim, evNull) => code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 97a03a79d28a3..7418870d9c731 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -1287,7 +1287,7 @@ class AnsiCastSuite extends CastSuiteBase { cast("2020-07-19", dataType), "invalid input") } - Seq(DecimalType.USER_DEFAULT).foreach { dataType => + Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType => checkExceptionInExpression[NumberFormatException]( cast("string", dataType), "invalid input") checkExceptionInExpression[NumberFormatException]( From c7dbeef3f538378edc67c0a45de267cd5fe125f6 Mon Sep 17 00:00:00 2001 From: root1 Date: Mon, 6 Jan 2020 17:39:08 +0530 Subject: [PATCH 09/15] Fix --- .../apache/spark/unsafe/types/UTF8String.java | 170 +++++++++++++++++- .../spark/sql/catalyst/expressions/Cast.scala | 97 +++++----- 2 files changed, 219 insertions(+), 48 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f87e9890dcc03..be291d1437e82 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1294,17 +1294,171 @@ public boolean toByte(IntWrapper intWrapper) { return false; } - public boolean toIntExact(IntWrapper intWrapper, boolean ansiEnabled, String integralType, UTF8String c) { - if ((integralType.equals("short") && toShort(intWrapper)) || - (integralType.equals("byte") && toByte(intWrapper)) || - (integralType.equals("int") && toInt(intWrapper)) ) { - return true; + /** + * Parses this UTF8String(trimmed if needed) to long. + * + * This method is almost similar to `toLong` defined above. It is used for parsing the UTF8String + * when ANSI mode is enabled. + * @return If string contains valid numeric value then it returns the long value otherwise a + * NumberFormatException is thrown. 
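+   * Overflow is detected by accumulating in the negative domain (result = result * 10 - digit),
+   * which lets Long.MIN_VALUE parse exactly: the multiplication is guarded by a check against
+   * Long.MIN_VALUE / 10, each subtraction is checked for a sign flip, and the final negation of
+   * a non-negative input is checked for wrap-around. At most one '.' is accepted; digits after
+   * it are validated and then discarded, and leading/trailing characters <= ' ' are trimmed.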
+ */ + + public long toLongExact() { + int offset = 0; + while (offset < this.numBytes && getByte(offset) <= ' ') offset++; + if (offset == this.numBytes) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + int end = this.numBytes - 1; + while (end > offset && getByte(end) <= ' ') end--; + + byte b = getByte(offset); + final boolean negative = b == '-'; + if (negative || b == '+') { + if (end - offset == 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + offset++; + } + + final byte separator = '.'; + final int radix = 10; + final long stopValue = Long.MIN_VALUE / radix; + long result = 0; + + while (offset <= end) { + b = getByte(offset); + offset++; + if (b == separator) { + break; } - if (ansiEnabled) { - throw new NumberFormatException("invalid input syntax for type numeric: " + c); + + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; } else { - return false; + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + if (result < stopValue) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); } + + result = result * radix - digit; + if (result > 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + } + + while (offset <= end) { + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + } + + return result; + } + + /** + * Parses this UTF8String(trimmed if needed) to Int. + * + * This method is almost similar to `toInt` defined above. It is used for parsing the UTF8String + * when ANSI mode is enabled. + * @return If string contains valid numeric value then it returns the int value otherwise a + * NumberFormatException is thrown. 
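+   * The parsing mirrors toLongExact above, using int arithmetic with Integer.MIN_VALUE / 10 as
+   * the overflow guard before each multiplication.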
+ */ + + public int toIntExact() { + int offset = 0; + while (offset < this.numBytes && getByte(offset) <= ' ') offset++; + if (offset == this.numBytes) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + int end = this.numBytes - 1; + while (end > offset && getByte(end) <= ' ') end--; + + byte b = getByte(offset); + final boolean negative = b == '-'; + if (negative || b == '+') { + if (end - offset == 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + offset++; + } + + final byte separator = '.'; + final int radix = 10; + final int stopValue = Integer.MIN_VALUE / radix; + int result = 0; + + while (offset <= end) { + b = getByte(offset); + offset++; + if (b == separator) { + break; + } + + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + if (result < stopValue) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + result = result * radix - digit; + if (result > 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + } + + while (offset <= end) { + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + } + return result; + } + + public short toShortExact() { + int value = this.toIntExact(); + short result = (short) value; + if (result == value) { + return result; + } + throw new NumberFormatException("invalid input syntax for type numeric: " + this); + } + + public byte toByteExact() { + int value = this.toIntExact(); + byte result = (byte) value; + if (result == value) { + return result; + } + throw new NumberFormatException("invalid input syntax for type numeric: " + this); } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e5a5f5b367266..57bf3be46c624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -483,14 +483,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType if ansiEnabled => - val result = new LongWrapper() - buildCast[UTF8String](_, s => { - if (s.toLong(result)) { - result.value - } else { - throw new NumberFormatException(s"invalid input syntax for type numeric: $s") - } - }) + buildCast[UTF8String](_, _.toLongExact()) case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) @@ -508,15 +501,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + buildCast[UTF8String](_, _.toIntExact()) case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => { - if (s.toIntExact(result, ansiEnabled, "int", s)) { - result.value - } else { - null - } - }) + buildCast[UTF8String](_, s 
=> if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -533,9 +522,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + buildCast[UTF8String](_, _.toShortExact()) case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toIntExact(result, ansiEnabled, "short", s)) { + buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort } else { null @@ -574,9 +565,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { + case StringType if ansiEnabled => + buildCast[UTF8String](_, _.toByteExact()) case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toIntExact(result, ansiEnabled, "byte", s)) { + buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte } else { null @@ -1410,13 +1403,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => + val casting = if (ansiEnabled) { + s"$evPrim = $c.toByteExact();" + } else { + s""" + if ($c.toByte($wrapper)) { + $evPrim = (byte) $wrapper.value; + } else { + $evNull = true; + } + """ + } code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toIntExact($wrapper, $ansiEnabled, "byte", $c)) { - $evPrim = (byte) $wrapper.value; - } else { - $evNull = true; - } + $casting $wrapper = null; """ case BooleanType => @@ -1441,13 +1441,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => + val casting = if (ansiEnabled) { + s"$evPrim = $c.toShortExact();" + } else { + s""" + if ($c.toShort($wrapper)) { + $evPrim = (short) $wrapper.value; + } else { + $evNull = true; + } + """ + } code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toIntExact($wrapper, $ansiEnabled, "short", $c)) { - $evPrim = (short) $wrapper.value; - } else { - $evNull = true; - } + $casting $wrapper = null; """ case BooleanType => @@ -1472,13 +1479,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => + val casting = if (ansiEnabled) { + s"$evPrim = $c.toIntExact();" + } else { + s""" + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { + $evNull = true; + } + """ + } code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); - if ($c.toIntExact($wrapper, $ansiEnabled, "int", $c)) { - $evPrim = $wrapper.value; - } else { - $evNull = true; - } + $casting $wrapper = null; """ case BooleanType => @@ -1502,17 +1516,20 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - code""" - UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); - if ($c.toLong($wrapper)) { - $evPrim = $wrapper.value; - } else { - if ($ansiEnabled) { - throw new NumberFormatException("invalid input 
syntax for type numeric: $c"); + val casting = if (ansiEnabled) { + s"$evPrim = $c.toLongExact();" + } else { + s""" + if ($c.toLong($wrapper)) { + $evPrim = $wrapper.value; } else { $evNull = true; } - } + """ + } + code""" + UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); + $casting $wrapper = null; """ case BooleanType => From 7d0faa6636db5f8fd6ef3ff848e200d74276f2a3 Mon Sep 17 00:00:00 2001 From: root1 Date: Mon, 6 Jan 2020 17:49:36 +0530 Subject: [PATCH 10/15] Fix --- .../apache/spark/sql/catalyst/expressions/Cast.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 57bf3be46c624..ba28ba350220a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1397,9 +1397,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } - private[this] def castToByteCode( - from: DataType, - ctx: CodegenContext): CastFunction = from match { + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => @@ -1473,9 +1471,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (short) $c;" } - private[this] def castToIntCode( - from: DataType, - ctx: CodegenContext): CastFunction = from match { + private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => @@ -1510,9 +1506,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = (int) $c;" } - private[this] def castToLongCode( - from: DataType, - ctx: CodegenContext): CastFunction = from match { + private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => From d45445227675eb6ee9bcb29d2cb85ee13af98578 Mon Sep 17 00:00:00 2001 From: root1 Date: Mon, 6 Jan 2020 20:01:45 +0530 Subject: [PATCH 11/15] Fix --- .../apache/spark/unsafe/types/UTF8String.java | 141 ++---------------- 1 file changed, 10 insertions(+), 131 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index be291d1437e82..c5384669eb922 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1295,152 +1295,31 @@ public boolean toByte(IntWrapper intWrapper) { } /** - * Parses this UTF8String(trimmed if needed) to long. + * Parses UTF8String(trimmed if needed) to long. This method is used when ANSI is enabled. * - * This method is almost similar to `toLong` defined above. It is used for parsing the UTF8String - * when ANSI mode is enabled. * @return If string contains valid numeric value then it returns the long value otherwise a * NumberFormatException is thrown. 
*/ - public long toLongExact() { - int offset = 0; - while (offset < this.numBytes && getByte(offset) <= ' ') offset++; - if (offset == this.numBytes) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - int end = this.numBytes - 1; - while (end > offset && getByte(end) <= ' ') end--; - - byte b = getByte(offset); - final boolean negative = b == '-'; - if (negative || b == '+') { - if (end - offset == 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - offset++; + LongWrapper result = new LongWrapper(); + if (toLong(result)) { + return result.value; } - - final byte separator = '.'; - final int radix = 10; - final long stopValue = Long.MIN_VALUE / radix; - long result = 0; - - while (offset <= end) { - b = getByte(offset); - offset++; - if (b == separator) { - break; - } - - int digit; - if (b >= '0' && b <= '9') { - digit = b - '0'; - } else { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - if (result < stopValue) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - result = result * radix - digit; - if (result > 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - } - - while (offset <= end) { - byte currentByte = getByte(offset); - if (currentByte < '0' || currentByte > '9') { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - offset++; - } - - if (!negative) { - result = -result; - if (result < 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - } - - return result; + throw new NumberFormatException("invalid input syntax for type numeric: " + this); } /** - * Parses this UTF8String(trimmed if needed) to Int. + * Parses UTF8String(trimmed if needed) to int. This method is used when ANSI is enabled. * - * This method is almost similar to `toInt` defined above. It is used for parsing the UTF8String - * when ANSI mode is enabled. * @return If string contains valid numeric value then it returns the int value otherwise a * NumberFormatException is thrown. 
*/ - public int toIntExact() { - int offset = 0; - while (offset < this.numBytes && getByte(offset) <= ' ') offset++; - if (offset == this.numBytes) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - int end = this.numBytes - 1; - while (end > offset && getByte(end) <= ' ') end--; - - byte b = getByte(offset); - final boolean negative = b == '-'; - if (negative || b == '+') { - if (end - offset == 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - offset++; + IntWrapper result = new IntWrapper(); + if (toInt(result)) { + return result.value; } - - final byte separator = '.'; - final int radix = 10; - final int stopValue = Integer.MIN_VALUE / radix; - int result = 0; - - while (offset <= end) { - b = getByte(offset); - offset++; - if (b == separator) { - break; - } - - int digit; - if (b >= '0' && b <= '9') { - digit = b - '0'; - } else { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - if (result < stopValue) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - - result = result * radix - digit; - if (result > 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - } - - while (offset <= end) { - byte currentByte = getByte(offset); - if (currentByte < '0' || currentByte > '9') { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - offset++; - } - - if (!negative) { - result = -result; - if (result < 0) { - throw new NumberFormatException("invalid input syntax for type numeric: " + this); - } - } - return result; + throw new NumberFormatException("invalid input syntax for type numeric: " + this); } public short toShortExact() { From 4b0149c7eb13d9cfe3d3c2261820a502546190f7 Mon Sep 17 00:00:00 2001 From: iRakson Date: Wed, 8 Jan 2020 14:52:11 +0530 Subject: [PATCH 12/15] Fix --- .../sql-tests/inputs/postgreSQL/float4.sql | 8 +- .../sql-tests/inputs/postgreSQL/float8.sql | 8 +- .../sql-tests/inputs/postgreSQL/text.sql | 4 +- .../inputs/postgreSQL/window_part2.sql | 8 +- .../inputs/postgreSQL/window_part4.sql | 6 +- .../results/postgreSQL/float4.sql.out | 204 +++++--- .../results/postgreSQL/float8.sql.out | 492 ++++++++++-------- .../sql-tests/results/postgreSQL/text.sql.out | 216 ++++---- .../results/postgreSQL/window_part2.sql.out | 20 +- .../results/postgreSQL/window_part4.sql.out | 13 +- 10 files changed, 546 insertions(+), 433 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql index 6ddb74f23fb7c..2989569e219ff 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql @@ -46,14 +46,14 @@ SELECT float('infinity'); SELECT float(' -INFINiTY '); -- [SPARK-27923] Spark SQL insert there bad special inputs to NULL -- bad special inputs --- SELECT float('N A N'); --- SELECT float('NaN x'); --- SELECT float(' INFINITY x'); +SELECT float('N A N'); +SELECT float('NaN x'); +SELECT float(' INFINITY x'); SELECT float('Infinity') + 100.0; SELECT float('Infinity') / float('Infinity'); SELECT float('nan') / float('nan'); --- SELECT float(decimal('nan')); +SELECT float(decimal('nan')); SELECT '' AS five, * FROM FLOAT4_TBL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql 
b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql index fabdcb0dce483..932cdb95fcf3a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql @@ -45,15 +45,15 @@ SELECT double('infinity'); SELECT double(' -INFINiTY '); -- [SPARK-27923] Spark SQL insert there bad special inputs to NULL -- bad special inputs --- SELECT double('N A N'); --- SELECT double('NaN x'); --- SELECT double(' INFINITY x'); +SELECT double('N A N'); +SELECT double('NaN x'); +SELECT double(' INFINITY x'); SELECT double('Infinity') + 100.0; SELECT double('Infinity') / double('Infinity'); SELECT double('NaN') / double('NaN'); -- [SPARK-28315] Decimal can not accept NaN as input --- SELECT double(decimal('nan')); +SELECT double(decimal('nan')); SELECT '' AS five, * FROM FLOAT8_TBL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql index 6e56485d1a8ef..05953123da86f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql @@ -25,8 +25,8 @@ select length(42); -- casting to text in concatenations, so long as the other input is text or -- an unknown literal. So these work: -- [SPARK-28033] String concatenation low priority than other arithmeticBinary --- select string('four: ') || 2+2; --- select 'four: ' || 2+2; +select string('four: ') || 2+2; +select 'four: ' || 2+2; -- but not this: -- Spark SQL implicit cast both side to string diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql index 62caf2378a50b..8f1832e75adba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql @@ -281,10 +281,10 @@ from numerics window w as (order by f_numeric range between 1 preceding and 1.1 following); -- currently unsupported --- select id, f_numeric, first(id) over w, last(id) over w --- from numerics --- window w as (order by f_numeric range between --- 1.1 preceding and 'NaN' following); -- error, NaN disallowed +select id, f_numeric, first(id) over w, last(id) over w +from numerics +window w as (order by f_numeric range between + 1.1 preceding and 'NaN' following); -- error, NaN disallowed drop table empsalary; drop table numerics; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql index 653231f3cc87c..932d9558761d2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql @@ -368,9 +368,9 @@ SELECT i,SUM(v) OVER (ORDER BY i ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) -- [SPARK-29638] Spark handles 'NaN' as 0 in sums -- ensure aggregate over numeric properly recovers from NaN values --- SELECT a, b, --- SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) --- FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b); +SELECT a, b, + SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) +FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b); -- It might be tempting for someone to add an inverse trans function for -- float and double precision. 
This should not be done as it can give incorrect diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index 82795de1a0782..7ca4c7eb8aa30 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 39 +-- Number of queries: 43 -- !query 0 @@ -91,34 +91,70 @@ struct -- !query 11 -SELECT float('Infinity') + 100.0 +SELECT float('N A N') -- !query 11 schema -struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> +struct<> -- !query 11 output -Infinity +java.lang.NumberFormatException +invalid input syntax for type numeric: N A N -- !query 12 -SELECT float('Infinity') / float('Infinity') +SELECT float('NaN x') -- !query 12 schema -struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) / CAST(CAST(Infinity AS FLOAT) AS DOUBLE)):double> +struct<> -- !query 12 output -NaN +java.lang.NumberFormatException +invalid input syntax for type numeric: NaN x -- !query 13 -SELECT float('nan') / float('nan') +SELECT float(' INFINITY x') -- !query 13 schema -struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double> +struct<> -- !query 13 output -NaN +java.lang.NumberFormatException +invalid input syntax for type numeric: INFINITY x -- !query 14 -SELECT '' AS five, * FROM FLOAT4_TBL +SELECT float('Infinity') + 100.0 -- !query 14 schema -struct +struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> -- !query 14 output +Infinity + + +-- !query 15 +SELECT float('Infinity') / float('Infinity') +-- !query 15 schema +struct<(CAST(CAST(Infinity AS FLOAT) AS DOUBLE) / CAST(CAST(Infinity AS FLOAT) AS DOUBLE)):double> +-- !query 15 output +NaN + + +-- !query 16 +SELECT float('nan') / float('nan') +-- !query 16 schema +struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double> +-- !query 16 output +NaN + + +-- !query 17 +SELECT float(decimal('nan')) +-- !query 17 schema +struct<> +-- !query 17 output +java.lang.NumberFormatException +invalid input syntax for type numeric: nan + + +-- !query 18 +SELECT '' AS five, * FROM FLOAT4_TBL +-- !query 18 schema +struct +-- !query 18 output -34.84 0.0 1.2345679E-20 @@ -126,116 +162,116 @@ struct 1004.3 --- !query 15 +-- !query 19 SELECT '' AS four, f.* FROM FLOAT4_TBL f WHERE f.f1 <> '1004.3' --- !query 15 schema +-- !query 19 schema struct --- !query 15 output +-- !query 19 output -34.84 0.0 1.2345679E-20 1.2345679E20 --- !query 16 +-- !query 20 SELECT '' AS one, f.* FROM FLOAT4_TBL f WHERE f.f1 = '1004.3' --- !query 16 schema +-- !query 20 schema struct --- !query 16 output +-- !query 20 output 1004.3 --- !query 17 +-- !query 21 SELECT '' AS three, f.* FROM FLOAT4_TBL f WHERE '1004.3' > f.f1 --- !query 17 schema +-- !query 21 schema struct --- !query 17 output +-- !query 21 output -34.84 0.0 1.2345679E-20 --- !query 18 +-- !query 22 SELECT '' AS three, f.* FROM FLOAT4_TBL f WHERE f.f1 < '1004.3' --- !query 18 schema +-- !query 22 schema struct --- !query 18 output +-- !query 22 output -34.84 0.0 1.2345679E-20 --- !query 19 +-- !query 23 SELECT '' AS four, f.* FROM FLOAT4_TBL f WHERE '1004.3' >= f.f1 --- !query 19 schema +-- !query 23 schema struct --- !query 19 output +-- !query 23 output -34.84 0.0 1.2345679E-20 1004.3 --- !query 20 +-- !query 24 SELECT '' AS four, f.* FROM FLOAT4_TBL 
f WHERE f.f1 <= '1004.3' --- !query 20 schema +-- !query 24 schema struct --- !query 20 output +-- !query 24 output -34.84 0.0 1.2345679E-20 1004.3 --- !query 21 +-- !query 25 SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 21 schema +-- !query 25 schema struct --- !query 21 output +-- !query 25 output 1.2345679E-20 -1.2345678720289608E-19 1.2345679E20 -1.2345678955701443E21 1004.3 -10042.999877929688 --- !query 22 +-- !query 26 SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 22 schema +-- !query 26 schema struct --- !query 22 output +-- !query 26 output 1.2345679E-20 -10.0 1.2345679E20 1.2345678955701443E20 1004.3 994.2999877929688 --- !query 23 +-- !query 27 SELECT '' AS three, f.f1, f.f1 / '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 23 schema +-- !query 27 schema struct --- !query 23 output +-- !query 27 output 1.2345679E-20 -1.2345678720289608E-21 1.2345679E20 -1.2345678955701443E19 1004.3 -100.42999877929688 --- !query 24 +-- !query 28 SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT4_TBL f WHERE f.f1 > '0.0' --- !query 24 schema +-- !query 28 schema struct --- !query 24 output +-- !query 28 output 1.2345679E-20 10.0 1.2345679E20 1.2345678955701443E20 1004.3 1014.2999877929688 --- !query 25 +-- !query 29 SELECT '' AS five, * FROM FLOAT4_TBL --- !query 25 schema +-- !query 29 schema struct --- !query 25 output +-- !query 29 output -34.84 0.0 1.2345679E-20 @@ -243,107 +279,107 @@ struct 1004.3 --- !query 26 +-- !query 30 SELECT smallint(float('32767.4')) --- !query 26 schema +-- !query 30 schema struct --- !query 26 output +-- !query 30 output 32767 --- !query 27 +-- !query 31 SELECT smallint(float('32767.6')) --- !query 27 schema +-- !query 31 schema struct --- !query 27 output +-- !query 31 output 32767 --- !query 28 +-- !query 32 SELECT smallint(float('-32768.4')) --- !query 28 schema +-- !query 32 schema struct --- !query 28 output +-- !query 32 output -32768 --- !query 29 +-- !query 33 SELECT smallint(float('-32768.6')) --- !query 29 schema +-- !query 33 schema struct --- !query 29 output +-- !query 33 output -32768 --- !query 30 +-- !query 34 SELECT int(float('2147483520')) --- !query 30 schema +-- !query 34 schema struct --- !query 30 output +-- !query 34 output 2147483520 --- !query 31 +-- !query 35 SELECT int(float('2147483647')) --- !query 31 schema +-- !query 35 schema struct --- !query 31 output +-- !query 35 output 2147483647 --- !query 32 +-- !query 36 SELECT int(float('-2147483648.5')) --- !query 32 schema +-- !query 36 schema struct --- !query 32 output +-- !query 36 output -2147483648 --- !query 33 +-- !query 37 SELECT int(float('-2147483900')) --- !query 33 schema +-- !query 37 schema struct<> --- !query 33 output +-- !query 37 output java.lang.ArithmeticException Casting -2.1474839E9 to int causes overflow --- !query 34 +-- !query 38 SELECT bigint(float('9223369837831520256')) --- !query 34 schema +-- !query 38 schema struct --- !query 34 output +-- !query 38 output 9223369837831520256 --- !query 35 +-- !query 39 SELECT bigint(float('9223372036854775807')) --- !query 35 schema +-- !query 39 schema struct --- !query 35 output +-- !query 39 output 9223372036854775807 --- !query 36 +-- !query 40 SELECT bigint(float('-9223372036854775808.5')) --- !query 36 schema +-- !query 40 schema struct --- !query 36 output +-- !query 40 output -9223372036854775808 --- !query 37 +-- !query 41 SELECT bigint(float('-9223380000000000000')) --- !query 37 schema +-- !query 41 schema 
struct<> --- !query 37 output +-- !query 41 output java.lang.ArithmeticException Casting -9.22338E18 to int causes overflow --- !query 38 +-- !query 42 DROP TABLE FLOAT4_TBL --- !query 38 schema +-- !query 42 schema struct<> --- !query 38 output +-- !query 42 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out index d6742f4c37b36..9d170d2fd8898 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 91 +-- Number of queries: 95 -- !query 0 @@ -123,34 +123,70 @@ struct -- !query 15 -SELECT double('Infinity') + 100.0 +SELECT double('N A N') -- !query 15 schema -struct<(CAST(Infinity AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> +struct<> -- !query 15 output -Infinity +java.lang.NumberFormatException +invalid input syntax for type numeric: N A N -- !query 16 -SELECT double('Infinity') / double('Infinity') +SELECT double('NaN x') -- !query 16 schema -struct<(CAST(Infinity AS DOUBLE) / CAST(Infinity AS DOUBLE)):double> +struct<> -- !query 16 output -NaN +java.lang.NumberFormatException +invalid input syntax for type numeric: NaN x -- !query 17 -SELECT double('NaN') / double('NaN') +SELECT double(' INFINITY x') -- !query 17 schema -struct<(CAST(NaN AS DOUBLE) / CAST(NaN AS DOUBLE)):double> +struct<> -- !query 17 output -NaN +java.lang.NumberFormatException +invalid input syntax for type numeric: INFINITY x -- !query 18 -SELECT '' AS five, * FROM FLOAT8_TBL +SELECT double('Infinity') + 100.0 -- !query 18 schema -struct +struct<(CAST(Infinity AS DOUBLE) + CAST(100.0 AS DOUBLE)):double> -- !query 18 output +Infinity + + +-- !query 19 +SELECT double('Infinity') / double('Infinity') +-- !query 19 schema +struct<(CAST(Infinity AS DOUBLE) / CAST(Infinity AS DOUBLE)):double> +-- !query 19 output +NaN + + +-- !query 20 +SELECT double('NaN') / double('NaN') +-- !query 20 schema +struct<(CAST(NaN AS DOUBLE) / CAST(NaN AS DOUBLE)):double> +-- !query 20 output +NaN + + +-- !query 21 +SELECT double(decimal('nan')) +-- !query 21 schema +struct<> +-- !query 21 output +java.lang.NumberFormatException +invalid input syntax for type numeric: nan + + +-- !query 22 +SELECT '' AS five, * FROM FLOAT8_TBL +-- !query 22 schema +struct +-- !query 22 output -34.84 0.0 1.2345678901234E-200 @@ -158,121 +194,121 @@ struct 1004.3 --- !query 19 +-- !query 23 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <> '1004.3' --- !query 19 schema +-- !query 23 schema struct --- !query 19 output +-- !query 23 output -34.84 0.0 1.2345678901234E-200 1.2345678901234E200 --- !query 20 +-- !query 24 SELECT '' AS one, f.* FROM FLOAT8_TBL f WHERE f.f1 = '1004.3' --- !query 20 schema +-- !query 24 schema struct --- !query 20 output +-- !query 24 output 1004.3 --- !query 21 +-- !query 25 SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE '1004.3' > f.f1 --- !query 21 schema +-- !query 25 schema struct --- !query 21 output +-- !query 25 output -34.84 0.0 1.2345678901234E-200 --- !query 22 +-- !query 26 SELECT '' AS three, f.* FROM FLOAT8_TBL f WHERE f.f1 < '1004.3' --- !query 22 schema +-- !query 26 schema struct --- !query 22 output +-- !query 26 output -34.84 0.0 1.2345678901234E-200 --- !query 23 +-- !query 27 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE '1004.3' >= f.f1 --- !query 23 schema +-- !query 27 schema struct --- !query 23 output 
+-- !query 27 output -34.84 0.0 1.2345678901234E-200 1004.3 --- !query 24 +-- !query 28 SELECT '' AS four, f.* FROM FLOAT8_TBL f WHERE f.f1 <= '1004.3' --- !query 24 schema +-- !query 28 schema struct --- !query 24 output +-- !query 28 output -34.84 0.0 1.2345678901234E-200 1004.3 --- !query 25 +-- !query 29 SELECT '' AS three, f.f1, f.f1 * '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 25 schema +-- !query 29 schema struct --- !query 25 output +-- !query 29 output 1.2345678901234E-200 -1.2345678901234E-199 1.2345678901234E200 -1.2345678901234E201 1004.3 -10043.0 --- !query 26 +-- !query 30 SELECT '' AS three, f.f1, f.f1 + '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 26 schema +-- !query 30 schema struct --- !query 26 output +-- !query 30 output 1.2345678901234E-200 -10.0 1.2345678901234E200 1.2345678901234E200 1004.3 994.3 --- !query 27 +-- !query 31 SELECT '' AS three, f.f1, f.f1 / '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 27 schema +-- !query 31 schema struct --- !query 27 output +-- !query 31 output 1.2345678901234E-200 -1.2345678901234E-201 1.2345678901234E200 -1.2345678901234E199 1004.3 -100.42999999999999 --- !query 28 +-- !query 32 SELECT '' AS three, f.f1, f.f1 - '-10' AS x FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 28 schema +-- !query 32 schema struct --- !query 28 output +-- !query 32 output 1.2345678901234E-200 10.0 1.2345678901234E200 1.2345678901234E200 1004.3 1014.3 --- !query 29 +-- !query 33 SELECT '' AS five, f.f1, round(f.f1) AS round_f1 FROM FLOAT8_TBL f --- !query 29 schema +-- !query 33 schema struct --- !query 29 output +-- !query 33 output -34.84 -35.0 0.0 0.0 1.2345678901234E-200 0.0 @@ -280,11 +316,11 @@ struct 1004.3 1004.0 --- !query 30 +-- !query 34 select ceil(f1) as ceil_f1 from float8_tbl f --- !query 30 schema +-- !query 34 schema struct --- !query 30 output +-- !query 34 output -34 0 1 @@ -292,11 +328,11 @@ struct 9223372036854775807 --- !query 31 +-- !query 35 select ceiling(f1) as ceiling_f1 from float8_tbl f --- !query 31 schema +-- !query 35 schema struct --- !query 31 output +-- !query 35 output -34 0 1 @@ -304,11 +340,11 @@ struct 9223372036854775807 --- !query 32 +-- !query 36 select floor(f1) as floor_f1 from float8_tbl f --- !query 32 schema +-- !query 36 schema struct --- !query 32 output +-- !query 36 output -35 0 0 @@ -316,11 +352,11 @@ struct 9223372036854775807 --- !query 33 +-- !query 37 select sign(f1) as sign_f1 from float8_tbl f --- !query 33 schema +-- !query 37 schema struct --- !query 33 output +-- !query 37 output -1.0 0.0 1.0 @@ -328,87 +364,87 @@ struct 1.0 --- !query 34 +-- !query 38 SELECT sqrt(double('64')) AS eight --- !query 34 schema +-- !query 38 schema struct --- !query 34 output +-- !query 38 output 8.0 --- !query 35 +-- !query 39 SELECT power(double('144'), double('0.5')) --- !query 35 schema +-- !query 39 schema struct --- !query 35 output +-- !query 39 output 12.0 --- !query 36 +-- !query 40 SELECT power(double('NaN'), double('0.5')) --- !query 36 schema +-- !query 40 schema struct --- !query 36 output +-- !query 40 output NaN --- !query 37 +-- !query 41 SELECT power(double('144'), double('NaN')) --- !query 37 schema +-- !query 41 schema struct --- !query 37 output +-- !query 41 output NaN --- !query 38 +-- !query 42 SELECT power(double('NaN'), double('NaN')) --- !query 38 schema +-- !query 42 schema struct --- !query 38 output +-- !query 42 output NaN --- !query 39 +-- !query 43 SELECT power(double('-1'), double('NaN')) --- !query 39 schema +-- !query 43 schema struct --- 
!query 39 output +-- !query 43 output NaN --- !query 40 +-- !query 44 SELECT power(double('1'), double('NaN')) --- !query 40 schema +-- !query 44 schema struct --- !query 40 output +-- !query 44 output NaN --- !query 41 +-- !query 45 SELECT power(double('NaN'), double('0')) --- !query 41 schema +-- !query 45 schema struct --- !query 41 output +-- !query 45 output 1.0 --- !query 42 +-- !query 46 SELECT '' AS three, f.f1, exp(ln(f.f1)) AS exp_ln_f1 FROM FLOAT8_TBL f WHERE f.f1 > '0.0' --- !query 42 schema +-- !query 46 schema struct --- !query 42 output +-- !query 46 output 1.2345678901234E-200 1.2345678901233948E-200 1.2345678901234E200 1.234567890123379E200 1004.3 1004.3000000000004 --- !query 43 +-- !query 47 SELECT '' AS five, * FROM FLOAT8_TBL --- !query 43 schema +-- !query 47 schema struct --- !query 43 output +-- !query 47 output -34.84 0.0 1.2345678901234E-200 @@ -416,22 +452,22 @@ struct 1004.3 --- !query 44 +-- !query 48 CREATE TEMPORARY VIEW UPDATED_FLOAT8_TBL as SELECT CASE WHEN FLOAT8_TBL.f1 > '0.0' THEN FLOAT8_TBL.f1 * '-1' ELSE FLOAT8_TBL.f1 END AS f1 FROM FLOAT8_TBL --- !query 44 schema +-- !query 48 schema struct<> --- !query 44 output +-- !query 48 output --- !query 45 +-- !query 49 SELECT '' AS bad, f.f1 * '1e200' from UPDATED_FLOAT8_TBL f --- !query 45 schema +-- !query 49 schema struct --- !query 45 output +-- !query 49 output -1.0042999999999999E203 -1.2345678901234 -3.484E201 @@ -439,11 +475,11 @@ struct 0.0 --- !query 46 +-- !query 50 SELECT '' AS five, * FROM UPDATED_FLOAT8_TBL --- !query 46 schema +-- !query 50 schema struct --- !query 46 output +-- !query 50 output -1.2345678901234E-200 -1.2345678901234E200 -1004.3 @@ -451,251 +487,251 @@ struct 0.0 --- !query 47 +-- !query 51 SELECT sinh(double('1')) --- !query 47 schema +-- !query 51 schema struct --- !query 47 output +-- !query 51 output 1.1752011936438014 --- !query 48 +-- !query 52 SELECT cosh(double('1')) --- !query 48 schema +-- !query 52 schema struct --- !query 48 output +-- !query 52 output 1.543080634815244 --- !query 49 +-- !query 53 SELECT tanh(double('1')) --- !query 49 schema +-- !query 53 schema struct --- !query 49 output +-- !query 53 output 0.7615941559557649 --- !query 50 +-- !query 54 SELECT asinh(double('1')) --- !query 50 schema +-- !query 54 schema struct --- !query 50 output +-- !query 54 output 0.8813735870195429 --- !query 51 +-- !query 55 SELECT acosh(double('2')) --- !query 51 schema +-- !query 55 schema struct --- !query 51 output +-- !query 55 output 1.3169578969248166 --- !query 52 +-- !query 56 SELECT atanh(double('0.5')) --- !query 52 schema +-- !query 56 schema struct --- !query 52 output +-- !query 56 output 0.5493061443340548 --- !query 53 +-- !query 57 SELECT sinh(double('Infinity')) --- !query 53 schema +-- !query 57 schema struct --- !query 53 output +-- !query 57 output Infinity --- !query 54 +-- !query 58 SELECT sinh(double('-Infinity')) --- !query 54 schema +-- !query 58 schema struct --- !query 54 output +-- !query 58 output -Infinity --- !query 55 +-- !query 59 SELECT sinh(double('NaN')) --- !query 55 schema +-- !query 59 schema struct --- !query 55 output +-- !query 59 output NaN --- !query 56 +-- !query 60 SELECT cosh(double('Infinity')) --- !query 56 schema +-- !query 60 schema struct --- !query 56 output +-- !query 60 output Infinity --- !query 57 +-- !query 61 SELECT cosh(double('-Infinity')) --- !query 57 schema +-- !query 61 schema struct --- !query 57 output +-- !query 61 output Infinity --- !query 58 +-- !query 62 SELECT cosh(double('NaN')) --- !query 58 schema 
+-- !query 62 schema struct --- !query 58 output +-- !query 62 output NaN --- !query 59 +-- !query 63 SELECT tanh(double('Infinity')) --- !query 59 schema +-- !query 63 schema struct --- !query 59 output +-- !query 63 output 1.0 --- !query 60 +-- !query 64 SELECT tanh(double('-Infinity')) --- !query 60 schema +-- !query 64 schema struct --- !query 60 output +-- !query 64 output -1.0 --- !query 61 +-- !query 65 SELECT tanh(double('NaN')) --- !query 61 schema +-- !query 65 schema struct --- !query 61 output +-- !query 65 output NaN --- !query 62 +-- !query 66 SELECT asinh(double('Infinity')) --- !query 62 schema +-- !query 66 schema struct --- !query 62 output +-- !query 66 output Infinity --- !query 63 +-- !query 67 SELECT asinh(double('-Infinity')) --- !query 63 schema +-- !query 67 schema struct --- !query 63 output +-- !query 67 output -Infinity --- !query 64 +-- !query 68 SELECT asinh(double('NaN')) --- !query 64 schema +-- !query 68 schema struct --- !query 64 output +-- !query 68 output NaN --- !query 65 +-- !query 69 SELECT acosh(double('Infinity')) --- !query 65 schema +-- !query 69 schema struct --- !query 65 output +-- !query 69 output Infinity --- !query 66 +-- !query 70 SELECT acosh(double('-Infinity')) --- !query 66 schema +-- !query 70 schema struct --- !query 66 output +-- !query 70 output NaN --- !query 67 +-- !query 71 SELECT acosh(double('NaN')) --- !query 67 schema +-- !query 71 schema struct --- !query 67 output +-- !query 71 output NaN --- !query 68 +-- !query 72 SELECT atanh(double('Infinity')) --- !query 68 schema +-- !query 72 schema struct --- !query 68 output +-- !query 72 output NaN --- !query 69 +-- !query 73 SELECT atanh(double('-Infinity')) --- !query 69 schema +-- !query 73 schema struct --- !query 69 output +-- !query 73 output NaN --- !query 70 +-- !query 74 SELECT atanh(double('NaN')) --- !query 70 schema +-- !query 74 schema struct --- !query 70 output +-- !query 74 output NaN --- !query 71 +-- !query 75 TRUNCATE TABLE FLOAT8_TBL --- !query 71 schema +-- !query 75 schema struct<> --- !query 71 output +-- !query 75 output --- !query 72 +-- !query 76 INSERT INTO FLOAT8_TBL VALUES (double('0.0')) --- !query 72 schema +-- !query 76 schema struct<> --- !query 72 output +-- !query 76 output --- !query 73 +-- !query 77 INSERT INTO FLOAT8_TBL VALUES (double('-34.84')) --- !query 73 schema +-- !query 77 schema struct<> --- !query 73 output +-- !query 77 output --- !query 74 +-- !query 78 INSERT INTO FLOAT8_TBL VALUES (double('-1004.30')) --- !query 74 schema +-- !query 78 schema struct<> --- !query 74 output +-- !query 78 output --- !query 75 +-- !query 79 INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e+200')) --- !query 75 schema +-- !query 79 schema struct<> --- !query 75 output +-- !query 79 output --- !query 76 +-- !query 80 INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e-200')) --- !query 76 schema +-- !query 80 schema struct<> --- !query 76 output +-- !query 80 output --- !query 77 +-- !query 81 SELECT '' AS five, * FROM FLOAT8_TBL --- !query 77 schema +-- !query 81 schema struct --- !query 77 output +-- !query 81 output -1.2345678901234E-200 -1.2345678901234E200 -1004.3 @@ -703,106 +739,106 @@ struct 0.0 --- !query 78 +-- !query 82 SELECT smallint(double('32767.4')) --- !query 78 schema +-- !query 82 schema struct --- !query 78 output +-- !query 82 output 32767 --- !query 79 +-- !query 83 SELECT smallint(double('32767.6')) --- !query 79 schema +-- !query 83 schema struct --- !query 79 output +-- !query 83 output 32767 --- !query 80 +-- 
!query 84 SELECT smallint(double('-32768.4')) --- !query 80 schema +-- !query 84 schema struct --- !query 80 output +-- !query 84 output -32768 --- !query 81 +-- !query 85 SELECT smallint(double('-32768.6')) --- !query 81 schema +-- !query 85 schema struct --- !query 81 output +-- !query 85 output -32768 --- !query 82 +-- !query 86 SELECT int(double('2147483647.4')) --- !query 82 schema +-- !query 86 schema struct --- !query 82 output +-- !query 86 output 2147483647 --- !query 83 +-- !query 87 SELECT int(double('2147483647.6')) --- !query 83 schema +-- !query 87 schema struct --- !query 83 output +-- !query 87 output 2147483647 --- !query 84 +-- !query 88 SELECT int(double('-2147483648.4')) --- !query 84 schema +-- !query 88 schema struct --- !query 84 output +-- !query 88 output -2147483648 --- !query 85 +-- !query 89 SELECT int(double('-2147483648.6')) --- !query 85 schema +-- !query 89 schema struct --- !query 85 output +-- !query 89 output -2147483648 --- !query 86 +-- !query 90 SELECT bigint(double('9223372036854773760')) --- !query 86 schema +-- !query 90 schema struct --- !query 86 output +-- !query 90 output 9223372036854773760 --- !query 87 +-- !query 91 SELECT bigint(double('9223372036854775807')) --- !query 87 schema +-- !query 91 schema struct --- !query 87 output +-- !query 91 output 9223372036854775807 --- !query 88 +-- !query 92 SELECT bigint(double('-9223372036854775808.5')) --- !query 88 schema +-- !query 92 schema struct --- !query 88 output +-- !query 92 output -9223372036854775808 --- !query 89 +-- !query 93 SELECT bigint(double('-9223372036854780000')) --- !query 89 schema +-- !query 93 schema struct<> --- !query 89 output +-- !query 93 output java.lang.ArithmeticException Casting -9.22337203685478E18 to long causes overflow --- !query 90 +-- !query 94 DROP TABLE FLOAT8_TBL --- !query 90 schema +-- !query 94 schema struct<> --- !query 90 output +-- !query 94 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index 3edc3b0197024..6faa5b6924460 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 40 +-- Number of queries: 42 -- !query 0 @@ -60,85 +60,103 @@ struct -- !query 7 -select 3 || 4.0 +select string('four: ') || 2+2 -- !query 7 schema -struct +struct<> -- !query 7 output -34.0 +java.lang.NumberFormatException +invalid input syntax for type numeric: four: 2 -- !query 8 +select 'four: ' || 2+2 +-- !query 8 schema +struct<> +-- !query 8 output +java.lang.NumberFormatException +invalid input syntax for type numeric: four: 2 + + +-- !query 9 +select 3 || 4.0 +-- !query 9 schema +struct +-- !query 9 output +34.0 + + +-- !query 10 /* * various string functions */ select concat('one') --- !query 8 schema +-- !query 10 schema struct --- !query 8 output +-- !query 10 output one --- !query 9 +-- !query 11 select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) --- !query 9 schema +-- !query 11 schema struct --- !query 9 output +-- !query 11 output 123hellotruefalse2010-03-09 --- !query 10 +-- !query 12 select concat_ws('#','one') --- !query 10 schema +-- !query 12 schema struct --- !query 10 output +-- !query 12 output one --- !query 11 +-- !query 13 select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) --- !query 11 schema +-- 
!query 13 schema struct --- !query 11 output +-- !query 13 output 1#x#x#hello#true#false#x-03-09 --- !query 12 +-- !query 14 select concat_ws(',',10,20,null,30) --- !query 12 schema +-- !query 14 schema struct --- !query 12 output +-- !query 14 output 10,20,30 --- !query 13 +-- !query 15 select concat_ws('',10,20,null,30) --- !query 13 schema +-- !query 15 schema struct --- !query 13 output +-- !query 15 output 102030 --- !query 14 +-- !query 16 select concat_ws(NULL,10,20,null,30) is null --- !query 14 schema +-- !query 16 schema struct<(concat_ws(CAST(NULL AS STRING), CAST(10 AS STRING), CAST(20 AS STRING), NULL, CAST(30 AS STRING)) IS NULL):boolean> --- !query 14 output +-- !query 16 output true --- !query 15 +-- !query 17 select reverse('abcde') --- !query 15 schema +-- !query 17 schema struct --- !query 15 output +-- !query 17 output edcba --- !query 16 +-- !query 18 select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i --- !query 16 schema +-- !query 18 schema struct --- !query 16 output +-- !query 18 output -5 -4 -3 @@ -152,192 +170,192 @@ struct 5 ahoj ahoj --- !query 17 +-- !query 19 /* * format */ select format_string(NULL) --- !query 17 schema +-- !query 19 schema struct --- !query 17 output +-- !query 19 output NULL --- !query 18 +-- !query 20 select format_string('Hello') --- !query 18 schema +-- !query 20 schema struct --- !query 18 output +-- !query 20 output Hello --- !query 19 +-- !query 21 select format_string('Hello %s', 'World') --- !query 19 schema +-- !query 21 schema struct --- !query 19 output +-- !query 21 output Hello World --- !query 20 +-- !query 22 select format_string('Hello %%') --- !query 20 schema +-- !query 22 schema struct --- !query 20 output +-- !query 22 output Hello % --- !query 21 +-- !query 23 select format_string('Hello %%%%') --- !query 21 schema +-- !query 23 schema struct --- !query 21 output +-- !query 23 output Hello %% --- !query 22 +-- !query 24 select format_string('Hello %s %s', 'World') --- !query 22 schema +-- !query 24 schema struct<> --- !query 22 output +-- !query 24 output java.util.MissingFormatArgumentException Format specifier '%s' --- !query 23 +-- !query 25 select format_string('Hello %s') --- !query 23 schema +-- !query 25 schema struct<> --- !query 23 output +-- !query 25 output java.util.MissingFormatArgumentException Format specifier '%s' --- !query 24 +-- !query 26 select format_string('Hello %x', 20) --- !query 24 schema +-- !query 26 schema struct --- !query 24 output +-- !query 26 output Hello 14 --- !query 25 +-- !query 27 select format_string('%1$s %3$s', 1, 2, 3) --- !query 25 schema +-- !query 27 schema struct --- !query 25 output +-- !query 27 output 1 3 --- !query 26 +-- !query 28 select format_string('%1$s %12$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) --- !query 26 schema +-- !query 28 schema struct --- !query 26 output +-- !query 28 output 1 12 --- !query 27 +-- !query 29 select format_string('%1$s %4$s', 1, 2, 3) --- !query 27 schema +-- !query 29 schema struct<> --- !query 27 output +-- !query 29 output java.util.MissingFormatArgumentException Format specifier '%4$s' --- !query 28 +-- !query 30 select format_string('%1$s %13$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) --- !query 28 schema +-- !query 30 schema struct<> --- !query 28 output +-- !query 30 output java.util.MissingFormatArgumentException Format specifier '%13$s' --- !query 29 +-- !query 31 select format_string('%0$s', 'Hello') --- !query 29 schema +-- !query 31 schema struct --- !query 29 output +-- !query 31 output Hello --- 
!query 30 +-- !query 32 select format_string('Hello %s %1$s %s', 'World', 'Hello again') --- !query 30 schema +-- !query 32 schema struct --- !query 30 output +-- !query 32 output Hello World World Hello again --- !query 31 +-- !query 33 select format_string('Hello %s %s, %2$s %2$s', 'World', 'Hello again') --- !query 31 schema +-- !query 33 schema struct --- !query 31 output +-- !query 33 output Hello World Hello again, Hello again Hello again --- !query 32 +-- !query 34 select format_string('>>%10s<<', 'Hello') --- !query 32 schema +-- !query 34 schema struct>%10s<<, Hello):string> --- !query 32 output +-- !query 34 output >> Hello<< --- !query 33 +-- !query 35 select format_string('>>%10s<<', NULL) --- !query 33 schema +-- !query 35 schema struct>%10s<<, NULL):string> --- !query 33 output +-- !query 35 output >> null<< --- !query 34 +-- !query 36 select format_string('>>%10s<<', '') --- !query 34 schema +-- !query 36 schema struct>%10s<<, ):string> --- !query 34 output +-- !query 36 output >> << --- !query 35 +-- !query 37 select format_string('>>%-10s<<', '') --- !query 35 schema +-- !query 37 schema struct>%-10s<<, ):string> --- !query 35 output +-- !query 37 output >> << --- !query 36 +-- !query 38 select format_string('>>%-10s<<', 'Hello') --- !query 36 schema +-- !query 38 schema struct>%-10s<<, Hello):string> --- !query 36 output +-- !query 38 output >>Hello << --- !query 37 +-- !query 39 select format_string('>>%-10s<<', NULL) --- !query 37 schema +-- !query 39 schema struct>%-10s<<, NULL):string> --- !query 37 output +-- !query 39 output >>null << --- !query 38 +-- !query 40 select format_string('>>%1$10s<<', 'Hello') --- !query 38 schema +-- !query 40 schema struct>%1$10s<<, Hello):string> --- !query 38 output +-- !query 40 output >> Hello<< --- !query 39 +-- !query 41 DROP TABLE TEXT_TBL --- !query 39 schema +-- !query 41 schema struct<> --- !query 39 output +-- !query 41 output diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out index bffc3f4f885a8..bbbdb1aaeab4a 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 31 -- !query 0 @@ -447,16 +447,28 @@ struct -- !query 28 output - +java.lang.NumberFormatException +invalid input syntax for type numeric: NaN -- !query 29 -drop table numerics +drop table empsalary -- !query 29 schema struct<> -- !query 29 output + + +-- !query 30 +drop table numerics +-- !query 30 schema +struct<> +-- !query 30 output + diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index b8a9b32027174..5588836f83163 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 39 -- !query 0 @@ -491,3 +491,14 @@ struct +-- !query 38 output +org.apache.spark.sql.AnalysisException +failed to evaluate expression CAST('nan' AS INT): invalid input syntax for type numeric: nan; line 3 pos 6 From 40afc54d8583ad86d46b3d14f9154c28aec505cf Mon Sep 17 00:00:00 2001 From: iRakson 
From 40afc54d8583ad86d46b3d14f9154c28aec505cf Mon Sep 17 00:00:00 2001
From: iRakson
Date: Wed, 8 Jan 2020 16:35:08 +0530
Subject: [PATCH 13/15] Fix

---
 .../spark/sql/catalyst/expressions/Cast.scala | 138 +++++-------------
 .../inputs/postgreSQL/window_part2.sql        |   2 +-
 .../inputs/postgreSQL/window_part4.sql        |   2 +-
 .../results/postgreSQL/window_part2.sql.out   |   2 +-
 .../results/postgreSQL/window_part4.sql.out   |   2 +-
 5 files changed, 44 insertions(+), 102 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index ba28ba350220a..e4674931e2c37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -672,27 +672,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   // DoubleConverter
   private[this] def castToDouble(from: DataType): Any => Any = from match {
-    case StringType if ansiEnabled =>
+    case StringType =>
       buildCast[UTF8String](_, s => {
         val doubleStr = s.toString
         try doubleStr.toDouble catch {
           case _: NumberFormatException =>
             val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
-            if(d == null) {
+            if(ansiEnabled && d == null) {
               throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
             } else {
-              d.asInstanceOf[Double].doubleValue()
+              d
             }
         }
       })
-    case StringType =>
-      buildCast[UTF8String](_, s => {
-        val doubleStr = s.toString
-        try doubleStr.toDouble catch {
-          case _: NumberFormatException =>
-            Cast.processFloatingPointSpecialLiterals(doubleStr, false)
-        }
-      })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1d else 0d)
     case DateType =>
@@ -705,27 +697,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   // FloatConverter
   private[this] def castToFloat(from: DataType): Any => Any = from match {
-    case StringType if ansiEnabled =>
+    case StringType =>
       buildCast[UTF8String](_, s => {
         val floatStr = s.toString
         try floatStr.toFloat catch {
           case _: NumberFormatException =>
             val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
-            if (f == null) {
+            if (ansiEnabled && f == null) {
               throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
             } else {
-              f.asInstanceOf[Float].floatValue()
+              f
             }
         }
       })
-    case StringType =>
-      buildCast[UTF8String](_, s => {
-        val floatStr = s.toString
-        try floatStr.toFloat catch {
-          case _: NumberFormatException =>
-            Cast.processFloatingPointSpecialLiterals(floatStr, true)
-        }
-      })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1f else 0f)
     case DateType =>
@@ -1398,23 +1382,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+    case StringType if ansiEnabled =>
+      (c, evPrim, evNull) => code"$evPrim = $c.toByteExact();"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
       (c, evPrim, evNull) =>
-        val casting = if (ansiEnabled) {
-          s"$evPrim = $c.toByteExact();"
-        } else {
-          s"""
-            if ($c.toByte($wrapper)) {
+        code"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
+          if ($c.toByte($wrapper)) {
              $evPrim = (byte) $wrapper.value;
            } else {
              $evNull = true;
            }
-          """
-        }
-        code"""
-          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
-          $casting
           $wrapper = null;
         """
     case BooleanType =>
@@ -1436,23 +1415,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   private[this] def castToShortCode(
       from: DataType,
       ctx: CodegenContext): CastFunction = from match {
+    case StringType if ansiEnabled =>
+      (c, evPrim, evNull) => code"$evPrim = $c.toShortExact();"
     case StringType =>
       val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
       (c, evPrim, evNull) =>
-        val casting = if (ansiEnabled) {
-          s"$evPrim = $c.toShortExact();"
-        } else {
-          s"""
-            if ($c.toShort($wrapper)) {
+        code"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
+          if ($c.toShort($wrapper)) {
              $evPrim = (short) $wrapper.value;
            } else {
              $evNull = true;
            }
-          """
-        }
-        code"""
-          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
-          $casting
           $wrapper = null;
         """
     case BooleanType =>
@@ -1472,23 +1446,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+    case StringType if ansiEnabled =>
+      (c, evPrim, evNull) => code"$evPrim = $c.toIntExact();"
     case StringType =>
      val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
       (c, evPrim, evNull) =>
-        val casting = if (ansiEnabled) {
-          s"$evPrim = $c.toIntExact();"
-        } else {
-          s"""
-            if ($c.toInt($wrapper)) {
+        code"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
+          if ($c.toInt($wrapper)) {
              $evPrim = $wrapper.value;
            } else {
              $evNull = true;
            }
-          """
-        }
-        code"""
-          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
-          $casting
           $wrapper = null;
         """
     case BooleanType =>
@@ -1507,23 +1476,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
+    case StringType if ansiEnabled =>
+      (c, evPrim, evNull) => code"$evPrim = $c.toLongExact();"
     case StringType =>
       val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
       (c, evPrim, evNull) =>
-        val casting = if (ansiEnabled) {
-          s"$evPrim = $c.toLongExact();"
-        } else {
-          s"""
-            if ($c.toLong($wrapper)) {
-              $evPrim = $wrapper.value;
-            } else {
-              $evNull = true;
-            }
-          """
-        }
         code"""
           UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
-          $casting
+          if ($c.toLong($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
+            $evNull = true;
+          }
           $wrapper = null;
         """
     case BooleanType =>
@@ -1543,25 +1507,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
-      case StringType if ansiEnabled =>
-        val floatStr = ctx.freshVariable("floatStr", StringType)
-        (c, evPrim, evNull) =>
-          code"""
-          final String $floatStr = $c.toString();
-          try {
-            $evPrim = Float.valueOf($floatStr);
-          } catch (java.lang.NumberFormatException e) {
-            final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
-            if (f == null) {
-              throw new NumberFormatException("invalid input syntax for type numeric: $c");
-            } else {
-              $evPrim = f.floatValue();
-            }
-          }
-          """
       case StringType =>
         val floatStr = ctx.freshVariable("floatStr", StringType)
         (c, evPrim, evNull) =>
+          val handleNull = if (ansiEnabled) {
+            s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
+          } else {
+            s"$evNull = true;"
+          }
           code"""
           final String $floatStr = $c.toString();
           try {
            $evPrim = Float.valueOf($floatStr);
           } catch (java.lang.NumberFormatException e) {
            final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
            if (f == null) {
-              $evNull = true;
+              $handleNull
            } else {
              $evPrim = f.floatValue();
            }
@@ -1590,25 +1543,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
-      case StringType if ansiEnabled =>
-        val doubleStr = ctx.freshVariable("doubleStr", StringType)
-        (c, evPrim, evNull) =>
-          code"""
-          final String $doubleStr = $c.toString();
-          try {
-            $evPrim = Double.valueOf($doubleStr);
-          } catch (java.lang.NumberFormatException e) {
-            final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
-            if (d == null) {
-              throw new NumberFormatException("invalid input syntax for type numeric: $c");
-            } else {
-              $evPrim = d.doubleValue();
-            }
-          }
-          """
      case StringType =>
        val doubleStr = ctx.freshVariable("doubleStr", StringType)
        (c, evPrim, evNull) =>
+          val handleNull = if (ansiEnabled) {
+            s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
+          } else {
+            s"$evNull = true;"
+          }
          code"""
          final String $doubleStr = $c.toString();
          try {
            $evPrim = Double.valueOf($doubleStr);
          } catch (java.lang.NumberFormatException e) {
            final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
            if (d == null) {
-              $evNull = true;
+              $handleNull
            } else {
              $evPrim = d.doubleValue();
            }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
index 8f1832e75adba..395149e48d5c8 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
@@ -284,7 +284,7 @@ window w as (order by f_numeric range between
 select id, f_numeric, first(id) over w, last(id) over w
 from numerics
 window w as (order by f_numeric range between
-             1.1 preceding and 'NaN' following); -- error, NaN disallowed 
+             1.1 preceding and 'NaN' following); -- error, NaN disallowed
 drop table empsalary;
 drop table numerics;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
index 932d9558761d2..64ba8e3b7a5ad 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
@@ -369,7 +369,7 @@ SELECT i,SUM(v) OVER (ORDER BY i ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
 -- [SPARK-29638] Spark handles 'NaN' as 0 in sums
 -- ensure aggregate over numeric properly recovers from NaN values
 SELECT a, b,
-       SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) 
+       SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
 FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b);
 
 -- It might be tempting for someone to add an inverse trans function for
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out
index bbbdb1aaeab4a..9183eb659237e 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out
@@ -450,7 +450,7 @@ struct
 -- !query 28 output
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
index 5588836f83163..e1c0499b32143 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
@@ -495,7 +495,7 @@ struct
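[Reviewer sketch, not part of the patch series] The patch above removes the duplicated ansi/non-ansi codegen branches by computing the failure snippet once and splicing it into a single template. A toy illustration of that pattern using plain string templating; CodegenSketch, handleNull and castToFloatCode are made-up names, not Spark's CodegenContext API:

    object CodegenSketch {
      // The Java statement to emit when parsing fails, chosen once per cast.
      def handleNull(ansiEnabled: Boolean, evNull: String, input: String): String =
        if (ansiEnabled) {
          s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $input);"""
        } else {
          s"$evNull = true;"
        }

      // One template serves both modes; only the catch body differs.
      def castToFloatCode(ansiEnabled: Boolean): String =
        s"""
           |try {
           |  prim = Float.valueOf(input.toString());
           |} catch (java.lang.NumberFormatException e) {
           |  ${handleNull(ansiEnabled, "isNull", "input")}
           |}
           |""".stripMargin
    }

Calling CodegenSketch.castToFloatCode(ansiEnabled = true) produces a snippet that throws on bad input; with ansiEnabled = false the same template sets the null flag instead, which mirrors how the consolidated StringType branch now covers both modes.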
From 2f845c3833117eeb527b6917c4808eb35f388dce Mon Sep 17 00:00:00 2001
From: iRakson
Date: Wed, 8 Jan 2020 16:40:24 +0530
Subject: [PATCH 14/15] Fix

---
 .../spark/sql/catalyst/expressions/Cast.scala | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index e4674931e2c37..aca7f4b1b3a81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1390,10 +1390,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         code"""
           UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
           if ($c.toByte($wrapper)) {
-             $evPrim = (byte) $wrapper.value;
-           } else {
-             $evNull = true;
-           }
+            $evPrim = (byte) $wrapper.value;
+          } else {
+            $evNull = true;
+          }
           $wrapper = null;
         """
     case BooleanType =>
@@ -1423,10 +1423,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         code"""
           UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
           if ($c.toShort($wrapper)) {
-             $evPrim = (short) $wrapper.value;
-           } else {
-             $evNull = true;
-           }
+            $evPrim = (short) $wrapper.value;
+          } else {
+            $evNull = true;
+          }
           $wrapper = null;
         """
     case BooleanType =>
@@ -1454,10 +1454,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         code"""
          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
          if ($c.toInt($wrapper)) {
-             $evPrim = $wrapper.value;
-           } else {
-             $evNull = true;
-           }
+            $evPrim = $wrapper.value;
+          } else {
+            $evNull = true;
+          }
           $wrapper = null;
         """
     case BooleanType =>
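[Reviewer sketch, not part of the patch series] The final patch below applies the same precomputed-snippet idea to the string-to-decimal codegen path. For reference, a standalone sketch of the intended interpreted-mode semantics using java.math.BigDecimal directly; DecimalCastSketch and castToDecimal are illustrative helpers, not Spark's Decimal/changePrecision machinery:

    object DecimalCastSketch {
      // None models a NULL result in the default mode; ANSI mode rejects bad input loudly.
      def castToDecimal(s: String, ansiEnabled: Boolean): Option[java.math.BigDecimal] =
        try Some(new java.math.BigDecimal(s.trim)) catch {
          case _: NumberFormatException if ansiEnabled =>
            throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
          case _: NumberFormatException => None
        }
    }

So castToDecimal("12.50", ansiEnabled = true) parses, castToDecimal("NaN", ansiEnabled = true) throws because BigDecimal has no NaN representation, and castToDecimal("NaN", ansiEnabled = false) returns None.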
From 0cb4edc3a42badfff3f2ecca7b2ec468ad0f0c2d Mon Sep 17 00:00:00 2001
From: iRakson
Date: Wed, 8 Jan 2020 18:57:14 +0530
Subject: [PATCH 15/15] Fix

---
 .../apache/spark/sql/catalyst/expressions/Cast.scala | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index aca7f4b1b3a81..4fd74a4e4658b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1151,16 +1151,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     from match {
       case StringType =>
         (c, evPrim, evNull) =>
+          val handleException = if (ansiEnabled) {
+            s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
+          } else {
+            s"$evNull =true;"
+          }
           code"""
             try {
               Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
               ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
             } catch (java.lang.NumberFormatException e) {
-              if ($ansiEnabled) {
-                throw new NumberFormatException("invalid input syntax for type numeric: $c");
-              } else {
-                $evNull =true;
-              }
+              $handleException
             }
           """
       case BooleanType =>