Skip to content

Commit 678c4da

Browse files
committed
[SPARK-7266] Add ExpectsInputTypes to expressions when possible.
This should give us better analysis-time error messages (rather than runtime errors) and automatic type casting. Author: Reynold Xin <[email protected]> Closes #5796 from rxin/expected-input-types and squashes the following commits: c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when possible.
1 parent 8055411 commit 678c4da

File tree

5 files changed

+71
-56
lines changed

5 files changed

+71
-56
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -239,37 +239,43 @@ trait HiveTypeCoercion {
239239
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
240240

241241
// we should cast all timestamp/date/string compare into string compare
242-
case p: BinaryPredicate if p.left.dataType == StringType
243-
&& p.right.dataType == DateType =>
242+
case p: BinaryComparison if p.left.dataType == StringType &&
243+
p.right.dataType == DateType =>
244244
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
245-
case p: BinaryPredicate if p.left.dataType == DateType
246-
&& p.right.dataType == StringType =>
245+
case p: BinaryComparison if p.left.dataType == DateType &&
246+
p.right.dataType == StringType =>
247247
p.makeCopy(Array(Cast(p.left, StringType), p.right))
248-
case p: BinaryPredicate if p.left.dataType == StringType
249-
&& p.right.dataType == TimestampType =>
248+
case p: BinaryComparison if p.left.dataType == StringType &&
249+
p.right.dataType == TimestampType =>
250250
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
251-
case p: BinaryPredicate if p.left.dataType == TimestampType
252-
&& p.right.dataType == StringType =>
251+
case p: BinaryComparison if p.left.dataType == TimestampType &&
252+
p.right.dataType == StringType =>
253253
p.makeCopy(Array(Cast(p.left, StringType), p.right))
254-
case p: BinaryPredicate if p.left.dataType == TimestampType
255-
&& p.right.dataType == DateType =>
254+
case p: BinaryComparison if p.left.dataType == TimestampType &&
255+
p.right.dataType == DateType =>
256256
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
257-
case p: BinaryPredicate if p.left.dataType == DateType
258-
&& p.right.dataType == TimestampType =>
257+
case p: BinaryComparison if p.left.dataType == DateType &&
258+
p.right.dataType == TimestampType =>
259259
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
260260

261-
case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
261+
case p: BinaryComparison if p.left.dataType == StringType &&
262+
p.right.dataType != StringType =>
262263
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
263-
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
264+
case p: BinaryComparison if p.left.dataType != StringType &&
265+
p.right.dataType == StringType =>
264266
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
265267

266-
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
268+
case i @ In(a, b) if a.dataType == DateType &&
269+
b.forall(_.dataType == StringType) =>
267270
i.makeCopy(Array(Cast(a, StringType), b))
268-
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
271+
case i @ In(a, b) if a.dataType == TimestampType &&
272+
b.forall(_.dataType == StringType) =>
269273
i.makeCopy(Array(Cast(a, StringType), b))
270-
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
274+
case i @ In(a, b) if a.dataType == DateType &&
275+
b.forall(_.dataType == TimestampType) =>
271276
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
272-
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
277+
case i @ In(a, b) if a.dataType == TimestampType &&
278+
b.forall(_.dataType == DateType) =>
273279
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
274280

275281
case Sum(e) if e.dataType == StringType =>
@@ -420,19 +426,19 @@ trait HiveTypeCoercion {
420426
)
421427

422428
case LessThan(e1 @ DecimalType.Expression(p1, s1),
423-
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
429+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
424430
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
425431

426432
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
427-
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
433+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
428434
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
429435

430436
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
431-
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
437+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
432438
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
433439

434440
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
435-
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
441+
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
436442
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
437443

438444
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
@@ -481,8 +487,8 @@ trait HiveTypeCoercion {
481487
// No need to change the EqualNullSafe operators either
482488
case e: EqualNullSafe => e
483489
// Otherwise turn them into Byte types so that there exists an ordering.
484-
case p: BinaryComparison
485-
if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
490+
case p: BinaryComparison if p.left.dataType == BooleanType &&
491+
p.right.dataType == BooleanType =>
486492
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
487493
}
488494
}
@@ -564,10 +570,6 @@ trait HiveTypeCoercion {
564570
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
565571
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
566572

567-
// Compatible with Hive
568-
case Substring(e, start, len) if e.dataType != StringType =>
569-
Substring(Cast(e, StringType), start, len)
570-
571573
// Coalesce should return the first non-null value, which could be any column
572574
// from the list. So we need to make sure the return type is deterministic and
573575
// compatible with every child column.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21-
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2221
import org.apache.spark.sql.catalyst.trees
2322
import org.apache.spark.sql.catalyst.trees.TreeNode
2423
import org.apache.spark.sql.types._
@@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
8685

8786
override def foldable: Boolean = left.foldable && right.foldable
8887

88+
override def nullable: Boolean = left.nullable || right.nullable
89+
8990
override def toString: String = s"($left $symbol $right)"
9091
}
9192

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {
7474

7575
type EvaluatedType = Any
7676

77-
def nullable: Boolean = left.nullable || right.nullable
78-
7977
override lazy val resolved =
8078
left.resolved && right.resolved &&
8179
left.dataType == right.dataType &&
8280
!DecimalType.isFixed(left.dataType)
8381

84-
def dataType: DataType = {
82+
override def dataType: DataType = {
8583
if (!resolved) {
8684
throw new UnresolvedException(this,
8785
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,14 @@ trait PredicateHelper {
7070
expr.references.subsetOf(plan.outputSet)
7171
}
7272

73-
abstract class BinaryPredicate extends BinaryExpression with Predicate {
74-
self: Product =>
75-
override def nullable: Boolean = left.nullable || right.nullable
76-
}
7773

78-
case class Not(child: Expression) extends UnaryExpression with Predicate {
74+
case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
7975
override def foldable: Boolean = child.foldable
8076
override def nullable: Boolean = child.nullable
8177
override def toString: String = s"NOT $child"
8278

79+
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
80+
8381
override def eval(input: Row): Any = {
8482
child.eval(input) match {
8583
case null => null
@@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
120118
}
121119
}
122120

123-
case class And(left: Expression, right: Expression) extends BinaryPredicate {
121+
case class And(left: Expression, right: Expression)
122+
extends BinaryExpression with Predicate with ExpectsInputTypes {
123+
124+
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
125+
124126
override def symbol: String = "&&"
125127

126128
override def eval(input: Row): Any = {
@@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
142144
}
143145
}
144146

145-
case class Or(left: Expression, right: Expression) extends BinaryPredicate {
147+
case class Or(left: Expression, right: Expression)
148+
extends BinaryExpression with Predicate with ExpectsInputTypes {
149+
150+
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
151+
146152
override def symbol: String = "||"
147153

148154
override def eval(input: Row): Any = {
@@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
164170
}
165171
}
166172

167-
abstract class BinaryComparison extends BinaryPredicate {
173+
abstract class BinaryComparison extends BinaryExpression with Predicate {
168174
self: Product =>
169175
}
170176

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.util.regex.Pattern
2222
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2323
import org.apache.spark.sql.types._
2424

25-
trait StringRegexExpression {
25+
trait StringRegexExpression extends ExpectsInputTypes {
2626
self: BinaryExpression =>
2727

2828
type EvaluatedType = Any
@@ -32,6 +32,7 @@ trait StringRegexExpression {
3232

3333
override def nullable: Boolean = left.nullable || right.nullable
3434
override def dataType: DataType = BooleanType
35+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
3536

3637
// try cache the pattern for Literal
3738
private lazy val cache: Pattern = right match {
@@ -57,11 +58,11 @@ trait StringRegexExpression {
5758
if(r == null) {
5859
null
5960
} else {
60-
val regex = pattern(r.asInstanceOf[UTF8String].toString)
61+
val regex = pattern(r.asInstanceOf[UTF8String].toString())
6162
if(regex == null) {
6263
null
6364
} else {
64-
matches(regex, l.asInstanceOf[UTF8String].toString)
65+
matches(regex, l.asInstanceOf[UTF8String].toString())
6566
}
6667
}
6768
}
@@ -110,16 +111,17 @@ case class RLike(left: Expression, right: Expression)
110111
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
111112
}
112113

113-
trait CaseConversionExpression {
114+
trait CaseConversionExpression extends ExpectsInputTypes {
114115
self: UnaryExpression =>
115116

116117
type EvaluatedType = Any
117118

118119
def convert(v: UTF8String): UTF8String
119120

120121
override def foldable: Boolean = child.foldable
121-
def nullable: Boolean = child.nullable
122-
def dataType: DataType = StringType
122+
override def nullable: Boolean = child.nullable
123+
override def dataType: DataType = StringType
124+
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
123125

124126
override def eval(input: Row): Any = {
125127
val evaluated = child.eval(input)
@@ -136,7 +138,7 @@ trait CaseConversionExpression {
136138
*/
137139
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
138140

139-
override def convert(v: UTF8String): UTF8String = v.toUpperCase
141+
override def convert(v: UTF8String): UTF8String = v.toUpperCase()
140142

141143
override def toString: String = s"Upper($child)"
142144
}
@@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
146148
*/
147149
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
148150

149-
override def convert(v: UTF8String): UTF8String = v.toLowerCase
151+
override def convert(v: UTF8String): UTF8String = v.toLowerCase()
150152

151153
override def toString: String = s"Lower($child)"
152154
}
153155

154156
/** A base trait for functions that compare two strings, returning a boolean. */
155157
trait StringComparison {
156-
self: BinaryPredicate =>
158+
self: BinaryExpression =>
159+
160+
def compare(l: UTF8String, r: UTF8String): Boolean
157161

158162
override type EvaluatedType = Any
159163

160164
override def nullable: Boolean = left.nullable || right.nullable
161165

162-
def compare(l: UTF8String, r: UTF8String): Boolean
163-
164166
override def eval(input: Row): Any = {
165167
val leftEval = left.eval(input)
166168
if(leftEval == null) {
@@ -181,31 +183,35 @@ trait StringComparison {
181183
* A function that returns true if the string `left` contains the string `right`.
182184
*/
183185
case class Contains(left: Expression, right: Expression)
184-
extends BinaryPredicate with StringComparison {
186+
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
185187
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
188+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
186189
}
187190

188191
/**
189192
* A function that returns true if the string `left` starts with the string `right`.
190193
*/
191194
case class StartsWith(left: Expression, right: Expression)
192-
extends BinaryPredicate with StringComparison {
195+
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
193196
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
197+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
194198
}
195199

196200
/**
197201
* A function that returns true if the string `left` ends with the string `right`.
198202
*/
199203
case class EndsWith(left: Expression, right: Expression)
200-
extends BinaryPredicate with StringComparison {
204+
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
201205
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
206+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
202207
}
203208

204209
/**
205210
* A function that takes a substring of its first argument starting at a given position.
206211
* Defined for String and Binary types.
207212
*/
208-
case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
213+
case class Substring(str: Expression, pos: Expression, len: Expression)
214+
extends Expression with ExpectsInputTypes {
209215

210216
type EvaluatedType = Any
211217

@@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
219225
if (str.dataType == BinaryType) str.dataType else StringType
220226
}
221227

228+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
229+
222230
override def children: Seq[Expression] = str :: pos :: len :: Nil
223231

224232
@inline
@@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
258266
val (st, end) = slicePos(start, length, () => ba.length)
259267
ba.slice(st, end)
260268
case s: UTF8String =>
261-
val (st, end) = slicePos(start, length, () => s.length)
269+
val (st, end) = slicePos(start, length, () => s.length())
262270
s.slice(st, end)
263271
}
264272
}

0 commit comments

Comments
 (0)