Skip to content

Commit e4727cc

Browse files
committed
BinaryOperator should not be doing implicit cast.
1 parent d017861 commit e4727cc

File tree

8 files changed

+170
-72
lines changed

8 files changed

+170
-72
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,16 @@ object HiveTypeCoercion {
221221
case e if !e.childrenResolved => e
222222

223223
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
224-
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
225-
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
226-
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
227-
b.makeCopy(Array(newLeft, newRight))
224+
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
225+
// If the expression accepts the tighest common type, cast to that.
226+
// Otherwise, don't do anything with the expression.
227+
if (b.inputType.acceptsType(commonType)) {
228+
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
229+
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
230+
b.makeCopy(Array(newLeft, newRight))
231+
} else {
232+
b
233+
}
228234
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
229235
}
230236
}
@@ -680,7 +686,7 @@ object HiveTypeCoercion {
680686
// Skip nodes who's children have not been resolved yet.
681687
case e if !e.childrenResolved => e
682688

683-
case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
689+
case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
684690
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
685691
// If we cannot do the implicit cast, just use the original input.
686692
implicitCast(in, expected).getOrElse(in)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -356,20 +356,32 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
356356
abstract class BinaryOperator extends BinaryExpression {
357357
self: Product =>
358358

359+
/**
360+
* Expected input type from both left/right child expressions, similar to the
361+
* [[ExpectsInputTypes]] trait.
362+
*/
363+
def inputType: AbstractDataType
364+
359365
def symbol: String
360366

367+
override def toString: String = s"($left $symbol $right)"
368+
361369
override def checkInputDataTypes(): TypeCheckResult = {
362-
// First call the checker for ExpectsInputTypes, and then check whether left and right have
363-
// the same type.
364-
super.checkInputDataTypes() match {
365-
case TypeCheckResult.TypeCheckSuccess =>
366-
if (left.dataType != right.dataType) {
367-
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
368-
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
369-
} else {
370-
TypeCheckResult.TypeCheckSuccess
371-
}
372-
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
370+
val mismatches = children.zipWithIndex.collect {
371+
case (child, idx) if !inputType.acceptsType(child.dataType) =>
372+
s"argument ${idx + 1} is expected to be of type ${inputType.simpleString}, " +
373+
s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
374+
}
375+
376+
if (mismatches.isEmpty) {
377+
if (left.dataType != right.dataType) {
378+
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
379+
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
380+
} else {
381+
TypeCheckResult.TypeCheckSuccess
382+
}
383+
} else {
384+
TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
373385
}
374386
}
375387
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ abstract class UnaryArithmetic extends UnaryExpression {
2929
override def dataType: DataType = child.dataType
3030
}
3131

32-
case class UnaryMinus(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
33-
34-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
32+
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
3533

3634
override def toString: String = s"-$child"
3735

36+
override def checkInputDataTypes(): TypeCheckResult =
37+
TypeUtils.checkForNumericExpr(child.dataType, "operator -")
38+
3839
private lazy val numeric = TypeUtils.getNumeric(dataType)
3940

4041
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
@@ -48,6 +49,9 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic with ExpectsInp
4849
case class UnaryPositive(child: Expression) extends UnaryArithmetic {
4950
override def prettyName: String = "positive"
5051

52+
override def checkInputDataTypes(): TypeCheckResult =
53+
TypeUtils.checkForNumericExpr(child.dataType, "operator -")
54+
5155
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
5256
defineCodeGen(ctx, ev, c => c)
5357

@@ -57,9 +61,9 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5761
/**
5862
* A function that get the absolute value of the numeric value.
5963
*/
60-
case class Abs(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
61-
62-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
64+
case class Abs(child: Expression) extends UnaryArithmetic {
65+
override def checkInputDataTypes(): TypeCheckResult =
66+
TypeUtils.checkForNumericExpr(child.dataType, "function abs")
6367

6468
private lazy val numeric = TypeUtils.getNumeric(dataType)
6569

@@ -91,14 +95,13 @@ private[sql] object BinaryArithmetic {
9195
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
9296
}
9397

94-
case class Add(left: Expression, right: Expression)
95-
extends BinaryArithmetic with ExpectsInputTypes {
98+
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
99+
100+
override def inputType: AbstractDataType = NumericType
96101

97102
override def symbol: String = "+"
98103
override def decimalMethod: String = "$plus"
99104

100-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
101-
102105
override lazy val resolved =
103106
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
104107

@@ -107,14 +110,13 @@ case class Add(left: Expression, right: Expression)
107110
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
108111
}
109112

110-
case class Subtract(left: Expression, right: Expression)
111-
extends BinaryArithmetic with ExpectsInputTypes {
113+
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
114+
115+
override def inputType: AbstractDataType = NumericType
112116

113117
override def symbol: String = "-"
114118
override def decimalMethod: String = "$minus"
115119

116-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
117-
118120
override lazy val resolved =
119121
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
120122

@@ -123,14 +125,13 @@ case class Subtract(left: Expression, right: Expression)
123125
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
124126
}
125127

126-
case class Multiply(left: Expression, right: Expression)
127-
extends BinaryArithmetic with ExpectsInputTypes {
128+
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
129+
130+
override def inputType: AbstractDataType = NumericType
128131

129132
override def symbol: String = "*"
130133
override def decimalMethod: String = "$times"
131134

132-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
133-
134135
override lazy val resolved =
135136
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
136137

@@ -139,15 +140,14 @@ case class Multiply(left: Expression, right: Expression)
139140
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
140141
}
141142

142-
case class Divide(left: Expression, right: Expression)
143-
extends BinaryArithmetic with ExpectsInputTypes {
143+
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
144+
145+
override def inputType: AbstractDataType = NumericType
144146

145147
override def symbol: String = "/"
146148
override def decimalMethod: String = "$div"
147149
override def nullable: Boolean = true
148150

149-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
150-
151151
override lazy val resolved =
152152
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
153153

@@ -205,15 +205,14 @@ case class Divide(left: Expression, right: Expression)
205205
}
206206
}
207207

208-
case class Remainder(left: Expression, right: Expression)
209-
extends BinaryArithmetic with ExpectsInputTypes {
208+
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
209+
210+
override def inputType: AbstractDataType = NumericType
210211

211212
override def symbol: String = "%"
212213
override def decimalMethod: String = "remainder"
213214
override def nullable: Boolean = true
214215

215-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
216-
217216
override lazy val resolved =
218217
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
219218

@@ -272,6 +271,10 @@ case class Remainder(left: Expression, right: Expression)
272271
}
273272

274273
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
274+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
275+
276+
override def inputType: AbstractDataType = TypeCollection.Ordered
277+
275278
override def nullable: Boolean = left.nullable && right.nullable
276279

277280
protected def checkTypesInternal(t: DataType) =
@@ -326,6 +329,10 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
326329
}
327330

328331
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
332+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
333+
334+
override def inputType: AbstractDataType = TypeCollection.Ordered
335+
329336
override def nullable: Boolean = left.nullable && right.nullable
330337

331338
protected def checkTypesInternal(t: DataType) =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2021
import org.apache.spark.sql.catalyst.expressions.codegen._
22+
import org.apache.spark.sql.catalyst.util.TypeUtils
2123
import org.apache.spark.sql.types._
2224

2325

@@ -26,13 +28,11 @@ import org.apache.spark.sql.types._
2628
*
2729
* Code generation inherited from BinaryArithmetic.
2830
*/
29-
case class BitwiseAnd(left: Expression, right: Expression)
30-
extends BinaryArithmetic with ExpectsInputTypes {
31+
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
3132

32-
override def symbol: String = "&"
33+
override def inputType: AbstractDataType = TypeCollection.Bitwise
3334

34-
override def inputTypes: Seq[AbstractDataType] =
35-
Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
35+
override def symbol: String = "&"
3636

3737
private lazy val and: (Any, Any) => Any = dataType match {
3838
case ByteType =>
@@ -53,13 +53,11 @@ case class BitwiseAnd(left: Expression, right: Expression)
5353
*
5454
* Code generation inherited from BinaryArithmetic.
5555
*/
56-
case class BitwiseOr(left: Expression, right: Expression)
57-
extends BinaryArithmetic with ExpectsInputTypes {
56+
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
5857

59-
override def symbol: String = "|"
58+
override def inputType: AbstractDataType = TypeCollection.Bitwise
6059

61-
override def inputTypes: Seq[AbstractDataType] =
62-
Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
60+
override def symbol: String = "|"
6361

6462
private lazy val or: (Any, Any) => Any = dataType match {
6563
case ByteType =>
@@ -80,13 +78,11 @@ case class BitwiseOr(left: Expression, right: Expression)
8078
*
8179
* Code generation inherited from BinaryArithmetic.
8280
*/
83-
case class BitwiseXor(left: Expression, right: Expression)
84-
extends BinaryArithmetic with ExpectsInputTypes {
81+
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
8582

86-
override def symbol: String = "^"
83+
override def inputType: AbstractDataType = TypeCollection.Bitwise
8784

88-
override def inputTypes: Seq[AbstractDataType] =
89-
Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
85+
override def symbol: String = "^"
9086

9187
private lazy val xor: (Any, Any) => Any = dataType match {
9288
case ByteType =>
@@ -105,11 +101,12 @@ case class BitwiseXor(left: Expression, right: Expression)
105101
/**
106102
* A function that calculates bitwise not(~) of a number.
107103
*/
108-
case class BitwiseNot(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
104+
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
109105

110-
override def toString: String = s"~$child"
106+
override def checkInputDataTypes(): TypeCheckResult =
107+
TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
111108

112-
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
109+
override def toString: String = s"~$child"
113110

114111
private lazy val not: (Any) => Any = dataType match {
115112
case ByteType =>

0 commit comments

Comments
 (0)