Skip to content

Commit 3a342de

Browse files
committed
Revert "[SPARK-8770][SQL] Create BinaryOperator abstract class."
This reverts commit 2727789.
1 parent 2727789 commit 3a342de

File tree

12 files changed

+135
-170
lines changed

12 files changed

+135
-170
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ object HiveTypeCoercion {
150150
* Converts string "NaN"s that are in binary operators with NaN-able types (Float / Double) to
151151
* the appropriate numeric equivalent.
152152
*/
153-
// TODO: remove this rule and make Cast handle Nan.
154153
object ConvertNaNs extends Rule[LogicalPlan] {
155154
private val StringNaN = Literal("NaN")
156155

@@ -160,19 +159,19 @@ object HiveTypeCoercion {
160159
case e if !e.childrenResolved => e
161160

162161
/* Double Conversions */
163-
case b @ BinaryOperator(StringNaN, right @ DoubleType()) =>
162+
case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
164163
b.makeCopy(Array(Literal(Double.NaN), right))
165-
case b @ BinaryOperator(left @ DoubleType(), StringNaN) =>
164+
case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
166165
b.makeCopy(Array(left, Literal(Double.NaN)))
167166

168167
/* Float Conversions */
169-
case b @ BinaryOperator(StringNaN, right @ FloatType()) =>
168+
case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
170169
b.makeCopy(Array(Literal(Float.NaN), right))
171-
case b @ BinaryOperator(left @ FloatType(), StringNaN) =>
170+
case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
172171
b.makeCopy(Array(left, Literal(Float.NaN)))
173172

174173
/* Use float NaN by default to avoid unnecessary type widening */
175-
case b @ BinaryOperator(left @ StringNaN, StringNaN) =>
174+
case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
176175
b.makeCopy(Array(left, Literal(Float.NaN)))
177176
}
178177
}
@@ -246,12 +245,12 @@ object HiveTypeCoercion {
246245

247246
Union(newLeft, newRight)
248247

249-
// Also widen types for BinaryOperator.
248+
// Also widen types for BinaryExpressions.
250249
case q: LogicalPlan => q transformExpressions {
251250
// Skip nodes whose children have not been resolved yet.
252251
case e if !e.childrenResolved => e
253252

254-
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
253+
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
255254
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
256255
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
257256
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
@@ -479,7 +478,7 @@ object HiveTypeCoercion {
479478

480479
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
481480
// and fixed-precision decimals in an expression with floats / doubles to doubles
482-
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
481+
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
483482
(left.dataType, right.dataType) match {
484483
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
485484
b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala

Lines changed: 0 additions & 59 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 83 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,17 @@ abstract class Expression extends TreeNode[Expression] {
119119
*/
120120
def childrenResolved: Boolean = children.forall(_.resolved)
121121

122+
/**
123+
* Returns a string representation of this expression that does not have developer centric
124+
* debugging information like the expression id.
125+
*/
126+
def prettyString: String = {
127+
transform {
128+
case a: AttributeReference => PrettyAttribute(a.name)
129+
case u: UnresolvedAttribute => PrettyAttribute(u.name)
130+
}.toString
131+
}
132+
122133
/**
123134
* Returns true when two expressions will always compute the same result, even if they differ
124135
* cosmetically (i.e. capitalization of names in attributes may be different).
@@ -143,40 +154,71 @@ abstract class Expression extends TreeNode[Expression] {
143154
* Note: it's not valid to call this method until `childrenResolved == true`.
144155
*/
145156
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
157+
}
158+
159+
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
160+
self: Product =>
161+
162+
def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")
163+
164+
override def foldable: Boolean = left.foldable && right.foldable
165+
166+
override def nullable: Boolean = left.nullable || right.nullable
167+
168+
override def toString: String = s"($left $symbol $right)"
146169

147170
/**
148-
* Returns a user-facing string representation of this expression's name.
149-
* This should usually match the name of the function in SQL.
171+
* Short hand for generating binary evaluation code.
172+
* If either of the sub-expressions is null, the result of this computation
173+
* is assumed to be null.
174+
*
175+
* @param f accepts two variable names and returns Java code to compute the output.
150176
*/
151-
def prettyName: String = getClass.getSimpleName.toLowerCase
177+
protected def defineCodeGen(
178+
ctx: CodeGenContext,
179+
ev: GeneratedExpressionCode,
180+
f: (String, String) => String): String = {
181+
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
182+
s"$result = ${f(eval1, eval2)};"
183+
})
184+
}
152185

153186
/**
154-
* Returns a user-facing string representation of this expression, i.e. does not have developer
155-
* centric debugging information like the expression id.
187+
* Short hand for generating binary evaluation code.
188+
* If either of the sub-expressions is null, the result of this computation
189+
* is assumed to be null.
156190
*/
157-
def prettyString: String = {
158-
transform {
159-
case a: AttributeReference => PrettyAttribute(a.name)
160-
case u: UnresolvedAttribute => PrettyAttribute(u.name)
161-
}.toString
191+
protected def nullSafeCodeGen(
192+
ctx: CodeGenContext,
193+
ev: GeneratedExpressionCode,
194+
f: (String, String, String) => String): String = {
195+
val eval1 = left.gen(ctx)
196+
val eval2 = right.gen(ctx)
197+
val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
198+
s"""
199+
${eval1.code}
200+
boolean ${ev.isNull} = ${eval1.isNull};
201+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
202+
if (!${ev.isNull}) {
203+
${eval2.code}
204+
if (!${eval2.isNull}) {
205+
$resultCode
206+
} else {
207+
${ev.isNull} = true;
208+
}
209+
}
210+
"""
162211
}
163-
164-
override def toString: String = prettyName + children.mkString("(", ",", ")")
165212
}
166213

214+
private[sql] object BinaryExpression {
215+
def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
216+
}
167217

168-
/**
169-
* A leaf expression, i.e. one without any child expressions.
170-
*/
171218
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
172219
self: Product =>
173220
}
174221

175-
176-
/**
177-
* An expression with one input and one output. The output is by default evaluated to null
178-
* if the input is evaluated to null.
179-
*/
180222
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
181223
self: Product =>
182224

@@ -223,76 +265,39 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
223265
}
224266
}
225267

226-
227268
/**
228-
* An expression with two inputs and one output. The output is by default evaluated to null
229-
* if any input is evaluated to null.
269+
* A trait that gets mixed in to define the expected input types of an expression.
230270
*/
231-
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
232-
self: Product =>
233-
234-
override def foldable: Boolean = left.foldable && right.foldable
235-
236-
override def nullable: Boolean = left.nullable || right.nullable
271+
trait ExpectsInputTypes { self: Expression =>
237272

238273
/**
239-
* Short hand for generating binary evaluation code.
240-
* If either of the sub-expressions is null, the result of this computation
241-
* is assumed to be null.
274+
* Expected input types from child expressions. The i-th position in the returned seq indicates
275+
* the type requirement for the i-th child.
242276
*
243-
* @param f accepts two variable names and returns Java code to compute the output.
277+
* The possible values at each position are:
278+
* 1. a specific data type, e.g. LongType, StringType.
279+
* 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
280+
* 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
244281
*/
245-
protected def defineCodeGen(
246-
ctx: CodeGenContext,
247-
ev: GeneratedExpressionCode,
248-
f: (String, String) => String): String = {
249-
nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
250-
s"$result = ${f(eval1, eval2)};"
251-
})
252-
}
282+
def inputTypes: Seq[Any]
253283

254-
/**
255-
* Short hand for generating binary evaluation code.
256-
* If either of the sub-expressions is null, the result of this computation
257-
* is assumed to be null.
258-
*/
259-
protected def nullSafeCodeGen(
260-
ctx: CodeGenContext,
261-
ev: GeneratedExpressionCode,
262-
f: (String, String, String) => String): String = {
263-
val eval1 = left.gen(ctx)
264-
val eval2 = right.gen(ctx)
265-
val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
266-
s"""
267-
${eval1.code}
268-
boolean ${ev.isNull} = ${eval1.isNull};
269-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
270-
if (!${ev.isNull}) {
271-
${eval2.code}
272-
if (!${eval2.isNull}) {
273-
$resultCode
274-
} else {
275-
${ev.isNull} = true;
276-
}
277-
}
278-
"""
284+
override def checkInputDataTypes(): TypeCheckResult = {
285+
// We will do the type checking in `HiveTypeCoercion`, so always returning success here.
286+
TypeCheckResult.TypeCheckSuccess
279287
}
280288
}
281289

282-
283290
/**
284-
* An expression that has two inputs that are expected to be the same type. If the two inputs have
285-
* different types, the analyzer will find the tightest common type and do the proper type casting.
291+
* Expressions that require a specific `DataType` as input should implement this trait
292+
* so that the proper type conversions can be performed in the analyzer.
286293
*/
287-
abstract class BinaryOperator extends BinaryExpression {
288-
self: Product =>
294+
trait AutoCastInputTypes { self: Expression =>
289295

290-
def symbol: String
296+
def inputTypes: Seq[DataType]
291297

292-
override def toString: String = s"($left $symbol $right)"
293-
}
294-
295-
296-
private[sql] object BinaryOperator {
297-
def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))
298+
override def checkInputDataTypes(): TypeCheckResult = {
299+
// We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
300+
// so type mismatch error won't be reported here, but for underlying `Cast`s.
301+
TypeCheckResult.TypeCheckSuccess
302+
}
298303
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi
2929

3030
override def nullable: Boolean = true
3131

32-
override def toString: String = s"UDF(${children.mkString(",")})"
32+
override def toString: String = s"scalaUDF(${children.mkString(",")})"
3333

3434
// scalastyle:off
3535

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
128128

129129
override def nullable: Boolean = true
130130
override def dataType: DataType = child.dataType
131+
override def toString: String = s"MAX($child)"
131132

132133
override def asPartial: SplitEvaluation = {
133134
val partialMax = Alias(Max(child), "PartialMax")()
@@ -161,6 +162,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
161162

162163
override def nullable: Boolean = false
163164
override def dataType: LongType.type = LongType
165+
override def toString: String = s"COUNT($child)"
164166

165167
override def asPartial: SplitEvaluation = {
166168
val partialCount = Alias(Count(child), "PartialCount")()
@@ -399,6 +401,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
399401
DoubleType
400402
}
401403

404+
override def toString: String = s"AVG($child)"
405+
402406
override def asPartial: SplitEvaluation = {
403407
child.dataType match {
404408
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
@@ -490,6 +494,8 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
490494
child.dataType
491495
}
492496

497+
override def toString: String = s"SUM($child)"
498+
493499
override def asPartial: SplitEvaluation = {
494500
child.dataType match {
495501
case DecimalType.Fixed(_, _) =>

0 commit comments

Comments
 (0)