Commit c65e532
[SPARK-8772][SQL] Implement implicit type cast for expressions that define input types.
1 parent 9fd13d5 commit c65e532

8 files changed: 82 additions and 131 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 33 additions & 6 deletions
@@ -704,19 +704,46 @@ object HiveTypeCoercion {

   /**
    * Casts types according to the expected input types for Expressions that have the trait
-   * [[AutoCastInputTypes]].
+   * [[ExpectsInputTypes]].
    */
   object ImplicitTypeCasts extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e

-      case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
-        val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
-          case (child, actual, expected) =>
-            if (actual == expected) child else Cast(child, expected)
+      case e: ExpectsInputTypes =>
+        val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
+          implicitCast(in, expected)
         }
-        e.withNewChildren(newC)
+        e.withNewChildren(children)
+    }
+
+    /**
+     * If needed, cast the expression into the expected type.
+     * If the implicit cast is not allowed, return the expression itself.
+     */
+    def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = {
+      (e, expectedType) match {
+        // Cast null type (usually from null literals) into target types
+        case (in @ NullType(), target: DataType) => Cast(in, target.defaultConcreteType)
+
+        // Implicit cast among numeric types
+        case (in @ NumericType(), target: NumericType) if in.dataType != target =>
+          Cast(in, target)
+
+        // Implicit cast between date time types
+        case (in @ DateType(), TimestampType) => Cast(in, TimestampType)
+        case (in @ TimestampType(), DateType) => Cast(in, DateType)
+
+        // Implicit cast from string to atomic types, and vice versa
+        case (in @ StringType(), target: AtomicType) if target != StringType =>
+          Cast(in, target.defaultConcreteType)
+        case (in, StringType) if in.dataType != StringType =>
+          Cast(in, StringType)
+
+        // Else, just return the same input expression
+        case (in, _) => in
+      }
     }
   }
 }
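
Note that the order of the cases above is the precedence order: null literals are resolved first, then numeric widening, then date/time conversions, with string conversions as the catch-all before giving up. A minimal standalone sketch of that decision order, using plain Scala stand-ins rather than the Catalyst classes (the DType objects below are hypothetical):

    // Simplified model of the implicitCast decision order; not the Catalyst API.
    sealed trait DType
    case object NullT extends DType
    case object IntT extends DType
    case object DoubleT extends DType
    case object DateT extends DType
    case object TimestampT extends DType
    case object StringT extends DType

    def numeric(t: DType): Boolean = t == IntT || t == DoubleT

    // Returns the type the input would be cast to, or the input type unchanged
    // when no implicit cast is allowed (mirroring "return the expression itself").
    def implicitCast(in: DType, expected: DType): DType = (in, expected) match {
      case (NullT, target)                        => target     // null literals
      case (a, b) if numeric(a) && numeric(b)     => b          // numeric widening
      case (DateT, TimestampT)                    => TimestampT // date -> timestamp
      case (TimestampT, DateT)                    => DateT      // timestamp -> date
      case (StringT, target) if target != StringT => target     // string -> atomic
      case (_, StringT)                           => StringT    // anything -> string
      case _                                      => in         // no cast allowed
    }

    // implicitCast(IntT, DoubleT) == DoubleT; implicitCast(DateT, IntT) == DateT (unchanged)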

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala

Lines changed: 3 additions & 19 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.AbstractDataType


 /**
@@ -32,28 +32,12 @@ trait ExpectsInputTypes { self: Expression =>
   *
   * The possible values at each position are:
   * 1. a specific data type, e.g. LongType, StringType.
-  * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
-  * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+  * 2. a non-leaf abstract data type, e.g. NumericType, IntegralType, FractionalType.
   */
-  def inputTypes: Seq[Any]
+  def inputTypes: Seq[AbstractDataType]

   override def checkInputDataTypes(): TypeCheckResult = {
     // We will do the type checking in `HiveTypeCoercion`, so always returning success here.
     TypeCheckResult.TypeCheckSuccess
   }
 }
-
-/**
- * Expressions that require a specific `DataType` as input should implement this trait
- * so that the proper type conversions can be performed in the analyzer.
- */
-trait AutoCastInputTypes { self: Expression =>
-
-  def inputTypes: Seq[DataType]
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
-    // so type mismatch error won't be reported here, but for underlying `Cast`s.
-    TypeCheckResult.TypeCheckSuccess
-  }
-}
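
With inputTypes typed as Seq[AbstractDataType], an expression can now declare an abstract expectation (e.g. NumericType) as well as a concrete one. A hedged sketch of what a conforming implementation might look like, modeled on the Md5 and StringLength definitions in this commit (SquareRoot is a hypothetical expression, and the package paths are assumed from this Spark 1.5-era source):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
    import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}

    case class SquareRoot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
      // Declaring DoubleType means ImplicitTypeCasts will wrap a numeric, string,
      // or null child in a Cast to DoubleType during analysis.
      override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
      override def dataType: DataType = DoubleType
      override def nullable: Boolean = child.nullable

      override def eval(input: InternalRow): Any = {
        val v = child.eval(input)
        if (v == null) null else math.sqrt(v.asInstanceOf[Double])
      }
    }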

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 3 additions & 4 deletions
@@ -56,8 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String)
  * @param name The short name of the function
  */
 abstract class UnaryMathExpression(f: Double => Double, name: String)
-  extends UnaryExpression with Serializable with AutoCastInputTypes {
-  self: Product =>
+  extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product =>

   override def inputTypes: Seq[DataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
@@ -96,7 +95,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
  * @param name The short name of the function
  */
 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
-  extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
+  extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>

   override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
@@ -208,7 +207,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
 }

 case class Bin(child: Expression)
-  extends UnaryExpression with Serializable with AutoCastInputTypes {
+  extends UnaryExpression with Serializable with ExpectsInputTypes {

   override def inputTypes: Seq[DataType] = Seq(LongType)
   override def dataType: DataType = StringType
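
Because UnaryMathExpression declares Seq(DoubleType) through ExpectsInputTypes, the new ImplicitTypeCasts rule can coerce its child instead of the old AutoCastInputTypes machinery. A sketch of the rewrite the analyzer would perform on an integer input (illustrative only; ToRadians, Literal, and Cast are the classes shown in this diff):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, ToRadians}
    import org.apache.spark.sql.types.DoubleType

    // Before analysis: an integer literal feeding a Double-typed math function.
    val raw = ToRadians(Literal(180))

    // After ImplicitTypeCasts (conceptually): the child is wrapped in a Cast to
    // DoubleType, since IntegerType and DoubleType are both numeric types.
    val coerced = ToRadians(Cast(Literal(180), DoubleType))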

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 4 additions & 8 deletions
@@ -31,8 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String
  * A function that calculates an MD5 128-bit checksum and returns it as a hex string
  * For input of type [[BinaryType]]
  */
-case class Md5(child: Expression)
-  extends UnaryExpression with AutoCastInputTypes {
+case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes {

   override def dataType: DataType = StringType

@@ -62,12 +61,10 @@ case class Md5(child: Expression)
  * the hash length is not one of the permitted values, the return value is NULL.
  */
 case class Sha2(left: Expression, right: Expression)
-  extends BinaryExpression with Serializable with AutoCastInputTypes {
+  extends BinaryExpression with Serializable with ExpectsInputTypes {

   override def dataType: DataType = StringType

-  override def toString: String = s"SHA2($left, $right)"
-
   override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

   override def eval(input: InternalRow): Any = {
@@ -147,7 +144,7 @@ case class Sha2(left: Expression, right: Expression)
  * A function that calculates a sha1 hash value and returns it as a hex string
  * For input of type [[BinaryType]] or [[StringType]]
  */
-case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes {
+case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {

   override def dataType: DataType = StringType

@@ -174,8 +171,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp
  * A function that computes a cyclic redundancy check value and returns it as a bigint
  * For input of type [[BinaryType]]
  */
-case class Crc32(child: Expression)
-  extends UnaryExpression with AutoCastInputTypes {
+case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes {

   override def dataType: DataType = LongType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 7 additions & 7 deletions
@@ -69,7 +69,7 @@ trait PredicateHelper {
     expr.references.subsetOf(plan.outputSet)
 }

-case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
+case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
   override def toString: String = s"NOT $child"

   override def inputTypes: Seq[DataType] = Seq(BooleanType)
@@ -120,11 +120,11 @@ case class InSet(value: Expression, hset: Set[Any])
 }

 case class And(left: Expression, right: Expression)
-  extends BinaryOperator with Predicate with AutoCastInputTypes {
+  extends BinaryExpression with Predicate with ExpectsInputTypes {

-  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def toString: String = s"$left && $right"

-  override def symbol: String = "&&"
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

   override def eval(input: InternalRow): Any = {
     val l = left.eval(input)
@@ -169,11 +169,11 @@ case class And(left: Expression, right: Expression)
 }

 case class Or(left: Expression, right: Expression)
-  extends BinaryOperator with Predicate with AutoCastInputTypes {
+  extends BinaryExpression with Predicate with ExpectsInputTypes {

-  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def toString: String = s"$left || $right"

-  override def symbol: String = "||"
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)

   override def eval(input: InternalRow): Any = {
     val l = left.eval(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 5 additions & 5 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

-trait StringRegexExpression extends AutoCastInputTypes {
+trait StringRegexExpression extends ExpectsInputTypes {
   self: BinaryExpression =>

   def escape(v: String): String
@@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
   override def toString: String = s"$left RLIKE $right"
 }

-trait CaseConversionExpression extends AutoCastInputTypes {
+trait CaseConversionExpression extends ExpectsInputTypes {
   self: UnaryExpression =>

   def convert(v: UTF8String): UTF8String
@@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
 }

 /** A base trait for functions that compare two strings, returning a boolean. */
-trait StringComparison extends AutoCastInputTypes {
+trait StringComparison extends ExpectsInputTypes {
   self: BinaryExpression =>

   def compare(l: UTF8String, r: UTF8String): Boolean
@@ -215,7 +215,7 @@ case class EndsWith(left: Expression, right: Expression)
  * Defined for String and Binary types.
  */
 case class Substring(str: Expression, pos: Expression, len: Expression)
-  extends Expression with AutoCastInputTypes {
+  extends Expression with ExpectsInputTypes {

   def this(str: Expression, pos: Expression) = {
     this(str, pos, Literal(Integer.MAX_VALUE))
@@ -283,7 +283,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
 /**
  * A function that returns the length of the given string expression.
  */
-case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
+case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[DataType] = Seq(StringType)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 2 additions & 82 deletions
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.types

-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{TypeTag, runtimeMirror}
 import scala.util.parsing.combinator.RegexParsers

 import org.json4s._
@@ -27,9 +25,7 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.util.Utils


 /**
@@ -39,7 +35,7 @@ import org.apache.spark.util.Utils
  * @group dataType
 */
 @DeveloperApi
-abstract class DataType {
+abstract class DataType extends AbstractDataType {
   /**
    * Enables matching against DataType for expressions:
    * {{{
@@ -80,84 +76,8 @@
    * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
    */
   private[spark] def asNullable: DataType
-}
-
-
-/**
- * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps.
- */
-protected[sql] abstract class AtomicType extends DataType {
-  private[sql] type InternalType
-  @transient private[sql] val tag: TypeTag[InternalType]
-  private[sql] val ordering: Ordering[InternalType]
-
-  @transient private[sql] val classTag = ScalaReflectionLock.synchronized {
-    val mirror = runtimeMirror(Utils.getSparkClassLoader)
-    ClassTag[InternalType](mirror.runtimeClass(tag.tpe))
-  }
-}
-
-
-/**
- * :: DeveloperApi ::
- * Numeric data types.
- *
- * @group dataType
- */
-abstract class NumericType extends AtomicType {
-  // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
-  // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
-  // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
-  // desugared by the compiler into an argument to the objects constructor. This means there is no
-  // longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
-  private[sql] val numeric: Numeric[InternalType]
-}
-
-
-private[sql] object NumericType {
-  /**
-   * Enables matching against NumericType for expressions:
-   * {{{
-   *   case Cast(child @ NumericType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-}
-
-
-private[sql] object IntegralType {
-  /**
-   * Enables matching against IntegralType for expressions:
-   * {{{
-   *   case Cast(child @ IntegralType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-}
-
-
-private[sql] abstract class IntegralType extends NumericType {
-  private[sql] val integral: Integral[InternalType]
-}
-
-
-private[sql] object FractionalType {
-  /**
-   * Enables matching against FractionalType for expressions:
-   * {{{
-   *   case Cast(child @ FractionalType(), StringType) =>
-   *     ...
-   * }}}
-   */
-  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-}

-private[sql] abstract class FractionalType extends NumericType {
-  private[sql] val fractional: Fractional[InternalType]
-  private[sql] val asIntegral: Integral[InternalType]
+  override def defaultConcreteType: DataType = this
 }

16383

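The NumericType, IntegralType, and FractionalType companion extractors deleted here move out of DataType.scala alongside the new AbstractDataType hierarchy; they are what make patterns like `in @ NumericType()` in ImplicitTypeCasts work, by testing the dataType of an Expression. A minimal standalone sketch of that extractor idiom, using plain Scala stand-ins rather than the Catalyst classes:

    // An extractor object whose unapply inspects a field of the scrutinee, so
    // `case e @ Numeric()` matches any Expr whose tpe is numeric.
    sealed trait Tpe
    case object IntTpe extends Tpe
    case object StrTpe extends Tpe

    final case class Expr(name: String, tpe: Tpe)

    object Numeric {
      def unapply(e: Expr): Boolean = e.tpe == IntTpe
    }

    def describe(e: Expr): String = e match {
      case matched @ Numeric() => s"${matched.name} is numeric"
      case other               => s"${other.name} is not numeric"
    }

    // describe(Expr("a", IntTpe)) == "a is numeric"
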
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 25 additions & 0 deletions
@@ -26,6 +26,31 @@ import org.apache.spark.sql.types._

 class HiveTypeCoercionSuite extends PlanTest {

+  test("implicit type cast") {
+    def shouldCast(from: DataType, to: AbstractDataType): Unit = {
+      val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+      assert(got.dataType === to.defaultConcreteType)
+    }
+
+    // TODO: write the entire implicit cast table out for test cases.
+    shouldCast(ByteType, IntegerType)
+    shouldCast(IntegerType, IntegerType)
+    shouldCast(IntegerType, LongType)
+    shouldCast(IntegerType, DecimalType.Unlimited)
+    shouldCast(LongType, IntegerType)
+    shouldCast(LongType, DecimalType.Unlimited)
+
+    shouldCast(DateType, TimestampType)
+    shouldCast(TimestampType, DateType)
+
+    shouldCast(StringType, IntegerType)
+    shouldCast(StringType, DateType)
+    shouldCast(StringType, TimestampType)
+    shouldCast(IntegerType, StringType)
+    shouldCast(DateType, StringType)
+    shouldCast(TimestampType, StringType)
+  }
+
   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
       var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2)
