Skip to content

Commit 03b70da

Browse files
committed
enhance implicit type cast
1 parent adb33d3 commit 03b70da

File tree

13 files changed

+70
-109
lines changed

13 files changed

+70
-109
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ object HiveTypeCoercion {
669669
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
670670
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
671671
if (b.inputType.acceptsType(commonType)) {
672-
// If the expression accepts the tighest common type, cast to that.
672+
// If the expression accepts the tightest common type, cast to that.
673673
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
674674
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
675675
b.makeCopy(Array(newLeft, newRight))
@@ -713,27 +713,22 @@ object HiveTypeCoercion {
713713
@Nullable val ret: Expression = (inType, expectedType) match {
714714

715715
// If the expected type is already a parent of the input type, no need to cast.
716-
case _ if expectedType.isSameType(inType) => e
716+
case _ if expectedType.acceptsType(inType) => e
717717

718718
// Cast null type (usually from null literals) into target types
719719
case (NullType, target) => Cast(e, target.defaultConcreteType)
720720

721-
// If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
722-
// already a number, leave it as is.
723-
case (_: NumericType, NumericType) => e
724-
725721
// If the function accepts any numeric type and the input is a string, we follow the hive
726722
// convention and cast that input into a double
727723
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
728724

729-
// Implicit cast among numeric types
725+
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
726+
730727
// If input is a numeric type but not decimal, and we expect a decimal type,
731728
// cast the input to unlimited precision decimal.
732-
case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
733-
Cast(e, DecimalType.Unlimited)
729+
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
734730
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
735-
case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
736-
case (_: NumericType, target: NumericType) => e
731+
case (_: NumericType, target: NumericType) => Cast(e, target)
737732

738733
// Implicit cast between date time types
739734
case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -747,15 +742,9 @@ object HiveTypeCoercion {
747742
case (StringType, BinaryType) => Cast(e, BinaryType)
748743
case (any, StringType) if any != StringType => Cast(e, StringType)
749744

750-
// Type collection.
751-
// First see if we can find our input type in the type collection. If we can, then just
752-
// use the current expression; otherwise, find the first one we can implicitly cast.
753-
case (_, TypeCollection(types)) =>
754-
if (types.exists(_.isSameType(inType))) {
755-
e
756-
} else {
757-
types.flatMap(implicitCast(e, _)).headOption.orNull
758-
}
745+
// When we reach here, input type is not acceptable for any types in this type collection,
746+
// try to find the first one we can implicitly cast.
747+
case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
759748

760749
// Else, no implicit cast applies to this input, so return null
761750
case _ => null

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
353353
* 2. Two inputs are expected to be the same type. If the two inputs have different types,
354354
* the analyzer will find the tightest common type and do the proper type casting.
355355
*/
356-
abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
356+
abstract class BinaryOperator extends BinaryExpression {
357357
self: Product =>
358358

359359
/**
@@ -366,20 +366,16 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
366366

367367
override def toString: String = s"($left $symbol $right)"
368368

369-
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
370-
371369
override def checkInputDataTypes(): TypeCheckResult = {
372-
// First call the checker for ExpectsInputTypes, and then check whether left and right have
373-
// the same type.
374-
super.checkInputDataTypes() match {
375-
case TypeCheckResult.TypeCheckSuccess =>
376-
if (left.dataType != right.dataType) {
377-
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
378-
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
379-
} else {
380-
TypeCheckResult.TypeCheckSuccess
381-
}
382-
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
370+
// First check whether left and right have the same type, then check if the type is acceptable.
371+
if (left.dataType != right.dataType) {
372+
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
373+
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
374+
} else if (!inputType.acceptsType(left.dataType)) {
375+
TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
376+
s" not ${left.dataType.simpleString}")
377+
} else {
378+
TypeCheckResult.TypeCheckSuccess
383379
}
384380
}
385381
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
320320
}
321321

322322
override def symbol: String = "max"
323-
override def prettyName: String = symbol
324323
}
325324

326325
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -375,5 +374,4 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
375374
}
376375

377376
override def symbol: String = "min"
378-
override def prettyName: String = symbol
379377
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
2828
*/
2929
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
3030

31-
override def inputType: AbstractDataType = TypeCollection.Bitwise
31+
override def inputType: AbstractDataType = IntegralType
3232

3333
override def symbol: String = "&"
3434

@@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
5353
*/
5454
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
5555

56-
override def inputType: AbstractDataType = TypeCollection.Bitwise
56+
override def inputType: AbstractDataType = IntegralType
5757

5858
override def symbol: String = "|"
5959

@@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
7878
*/
7979
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
8080

81-
override def inputType: AbstractDataType = TypeCollection.Bitwise
81+
override def inputType: AbstractDataType = IntegralType
8282

8383
override def symbol: String = "^"
8484

@@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
101101
*/
102102
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
103103

104-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
104+
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
105105

106106
override def dataType: DataType = child.dataType
107107

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
3535
TypeCheckResult.TypeCheckFailure(
3636
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
3737
} else if (trueValue.dataType != falseValue.dataType) {
38-
TypeCheckResult.TypeCheckFailure(
39-
s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
38+
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
39+
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
4040
} else {
4141
TypeCheckResult.TypeCheckSuccess
4242
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
3434
private[sql] def defaultConcreteType: DataType
3535

3636
/**
37-
* Returns true if this data type is the same type as `other`. This is different that equality
38-
* as equality will also consider data type parametrization, such as decimal precision.
37+
* Returns true if `other` is an acceptable input type for a function that expects this,
38+
* possibly abstract, DataType.
3939
*
4040
* {{{
4141
* // this should return true
42-
* DecimalType.isSameType(DecimalType(10, 2))
43-
*
44-
* // this should return false
45-
* NumericType.isSameType(DecimalType(10, 2))
46-
* }}}
47-
*/
48-
private[sql] def isSameType(other: DataType): Boolean
49-
50-
/**
51-
* Returns true if `other` is an acceptable input type for a function that expectes this,
52-
* possibly abstract, DataType.
53-
*
54-
* {{{
55-
* // this should return true
56-
* DecimalType.isSameType(DecimalType(10, 2))
42+
* DecimalType.acceptsType(DecimalType(10, 2))
5743
*
5844
* // this should return true as well
5945
* NumericType.acceptsType(DecimalType(10, 2))
6046
* }}}
6147
*/
62-
private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
48+
private[sql] def acceptsType(other: DataType): Boolean
6349

6450
/** Readable string representation for the type. */
6551
private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
8369

8470
override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
8571

86-
override private[sql] def isSameType(other: DataType): Boolean = false
87-
8872
override private[sql] def acceptsType(other: DataType): Boolean =
89-
types.exists(_.isSameType(other))
73+
types.exists(_.acceptsType(other))
9074

9175
override private[sql] def simpleString: String = {
9276
types.map(_.simpleString).mkString("(", " or ", ")")
@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
10791
TimestampType, DateType,
10892
StringType, BinaryType)
10993

110-
/**
111-
* Types that can be used in bitwise operations.
112-
*/
113-
val Bitwise = TypeCollection(
114-
BooleanType,
115-
ByteType, ShortType, IntegerType, LongType)
116-
11794
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
11895

11996
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {
134111

135112
override private[sql] def simpleString: String = "any"
136113

137-
override private[sql] def isSameType(other: DataType): Boolean = false
138-
139114
override private[sql] def acceptsType(other: DataType): Boolean = true
140115
}
141116

@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
183158

184159
override private[sql] def simpleString: String = "numeric"
185160

186-
override private[sql] def isSameType(other: DataType): Boolean = false
187-
188161
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
189162
}
190163

191164

192-
private[sql] object IntegralType {
165+
private[sql] object IntegralType extends AbstractDataType {
193166
/**
194167
* Enables matching against IntegralType for expressions:
195168
* {{{
@@ -198,6 +171,12 @@ private[sql] object IntegralType {
198171
* }}}
199172
*/
200173
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
174+
175+
override private[sql] def defaultConcreteType: DataType = IntegerType
176+
177+
override private[sql] def simpleString: String = "integral"
178+
179+
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
201180
}
202181

203182

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
2828

2929
override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
3030

31-
override private[sql] def isSameType(other: DataType): Boolean = {
31+
override private[sql] def acceptsType(other: DataType): Boolean = {
3232
other.isInstanceOf[ArrayType]
3333
}
3434

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
7979

8080
override private[sql] def defaultConcreteType: DataType = this
8181

82-
override private[sql] def isSameType(other: DataType): Boolean = this == other
82+
override private[sql] def acceptsType(other: DataType): Boolean = this == other
8383
}
8484

8585

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
8686

8787
override private[sql] def defaultConcreteType: DataType = Unlimited
8888

89-
override private[sql] def isSameType(other: DataType): Boolean = {
89+
override private[sql] def acceptsType(other: DataType): Boolean = {
9090
other.isInstanceOf[DecimalType]
9191
}
9292

sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
7171

7272
override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
7373

74-
override private[sql] def isSameType(other: DataType): Boolean = {
74+
override private[sql] def acceptsType(other: DataType): Boolean = {
7575
other.isInstanceOf[MapType]
7676
}
7777

0 commit comments

Comments
 (0)