
Commit 9fd13d5

[SPARK-8770][SQL] Create BinaryOperator abstract class.
Our current BinaryExpression abstract class is not for generic binary expressions, i.e. it requires left/right children to have the same type. However, due to its name, contributors build new binary expressions that don't have that assumption (e.g. Sha) and still extend BinaryExpression. This patch creates a new BinaryOperator abstract class and updates the analyzer to only apply the type casting rule there. This patch also adds the notion of "prettyName" to expressions, which defines the user-facing name for the expression.

Author: Reynold Xin <[email protected]>

Closes #7174 from rxin/binary-opterator and squashes the following commits:

f31900d [Reynold Xin] [SPARK-8770][SQL] Create BinaryOperator abstract class.
fceb216 [Reynold Xin] Merge branch 'master' of github.com:apache/spark into binary-opterator
d8518cf [Reynold Xin] Updated Python tests.
1 parent 3a342de commit 9fd13d5
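
To make the distinction concrete, here is a minimal Scala sketch of the hierarchy the commit message describes; the names follow the patch, but the members shown are simplified stand-ins rather than the real Catalyst definitions:

// Hedged sketch: the real classes live in
// org.apache.spark.sql.catalyst.expressions and carry much more machinery.
abstract class Expression {
  def children: Seq[Expression]
  // The new "prettyName" notion: a user-facing name, so result columns
  // render as avg(age) rather than AVG(age).
  def prettyName: String = getClass.getSimpleName.toLowerCase
}

// Generic two-child expression: left and right may have different types
// (e.g. Sha-style expressions keep extending this).
abstract class BinaryExpression extends Expression {
  def left: Expression
  def right: Expression
  override def children: Seq[Expression] = Seq(left, right)
}

// Operators such as + or =, where both children must end up with the same
// type; the analyzer applies its type-casting rule only to this subclass.
abstract class BinaryOperator extends BinaryExpression {
  def symbol: String
}

With this split, coercion rules can match on BinaryOperator and leave generic binary expressions untouched, as the HiveTypeCoercion changes below show.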

File tree: 15 files changed, +191 −155 lines


python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 5 deletions

@@ -802,11 +802,11 @@ def groupBy(self, *cols):
         Each element should be a column name (string) or an expression (:class:`Column`).
 
         >>> df.groupBy().avg().collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df.groupBy('name').agg({'age': 'mean'}).collect()
-        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> df.groupBy(df.name).avg().collect()
-        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> df.groupBy(['name', df.age]).count().collect()
         [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
@@ -864,10 +864,10 @@ def agg(self, *exprs):
         (shorthand for ``df.groupBy.agg()``).
 
         >>> df.agg({"age": "max"}).collect()
-        [Row(MAX(age)=5)]
+        [Row(max(age)=5)]
         >>> from pyspark.sql import functions as F
         >>> df.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2)]
+        [Row(min(age)=2)]
         """
         return self.groupBy().agg(*exprs)

python/pyspark/sql/functions.py

Lines changed: 2 additions & 2 deletions

@@ -266,7 +266,7 @@ def coalesce(*cols):
 
         >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
         +-------------+
-        |Coalesce(a,b)|
+        |coalesce(a,b)|
         +-------------+
         |         null|
         |            1|
@@ -275,7 +275,7 @@ def coalesce(*cols):
 
         >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
         +----+----+---------------+
-        |   a|   b|Coalesce(a,0.0)|
+        |   a|   b|coalesce(a,0.0)|
         +----+----+---------------+
         |null|null|            0.0|
         |   1|null|            1.0|

python/pyspark/sql/group.py

Lines changed: 12 additions & 12 deletions

@@ -75,11 +75,11 @@ def agg(self, *exprs):
 
         >>> gdf = df.groupBy(df.name)
         >>> gdf.agg({"*": "count"}).collect()
-        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
+        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
+        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -110,9 +110,9 @@ def mean(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().mean('age').collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df3.groupBy().mean('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
+        [Row(avg(age)=3.5, avg(height)=82.5)]
         """
 
     @df_varargs_api
@@ -125,9 +125,9 @@ def avg(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().avg('age').collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df3.groupBy().avg('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
+        [Row(avg(age)=3.5, avg(height)=82.5)]
         """
 
     @df_varargs_api
@@ -136,9 +136,9 @@ def max(self, *cols):
         """Computes the max value for each numeric column for each group.
 
         >>> df.groupBy().max('age').collect()
-        [Row(MAX(age)=5)]
+        [Row(max(age)=5)]
         >>> df3.groupBy().max('age', 'height').collect()
-        [Row(MAX(age)=5, MAX(height)=85)]
+        [Row(max(age)=5, max(height)=85)]
         """
 
     @df_varargs_api
@@ -149,9 +149,9 @@ def min(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().min('age').collect()
-        [Row(MIN(age)=2)]
+        [Row(min(age)=2)]
         >>> df3.groupBy().min('age', 'height').collect()
-        [Row(MIN(age)=2, MIN(height)=80)]
+        [Row(min(age)=2, min(height)=80)]
         """
 
     @df_varargs_api
@@ -162,9 +162,9 @@ def sum(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().sum('age').collect()
-        [Row(SUM(age)=7)]
+        [Row(sum(age)=7)]
         >>> df3.groupBy().sum('age', 'height').collect()
-        [Row(SUM(age)=7, SUM(height)=165)]
+        [Row(sum(age)=7, sum(height)=165)]
         """
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 9 additions & 8 deletions

@@ -150,6 +150,7 @@ object HiveTypeCoercion {
   * Converts string "NaN"s that are in binary operators with a NaN-able type (Float / Double) to
   * the appropriate numeric equivalent.
   */
+  // TODO: remove this rule and make Cast handle NaN.
  object ConvertNaNs extends Rule[LogicalPlan] {
    private val StringNaN = Literal("NaN")
 
@@ -159,19 +160,19 @@ object HiveTypeCoercion {
      case e if !e.childrenResolved => e
 
      /* Double Conversions */
-      case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
+      case b @ BinaryOperator(StringNaN, right @ DoubleType()) =>
        b.makeCopy(Array(Literal(Double.NaN), right))
-      case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
+      case b @ BinaryOperator(left @ DoubleType(), StringNaN) =>
        b.makeCopy(Array(left, Literal(Double.NaN)))
 
      /* Float Conversions */
-      case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
+      case b @ BinaryOperator(StringNaN, right @ FloatType()) =>
        b.makeCopy(Array(Literal(Float.NaN), right))
-      case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
+      case b @ BinaryOperator(left @ FloatType(), StringNaN) =>
        b.makeCopy(Array(left, Literal(Float.NaN)))
 
      /* Use float NaN by default to avoid unnecessary type widening */
-      case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
+      case b @ BinaryOperator(left @ StringNaN, StringNaN) =>
        b.makeCopy(Array(left, Literal(Float.NaN)))
    }
  }
@@ -245,12 +246,12 @@ object HiveTypeCoercion {
 
      Union(newLeft, newRight)
 
-    // Also widen types for BinaryExpressions.
+    // Also widen types for BinaryOperator.
    case q: LogicalPlan => q transformExpressions {
      // Skip nodes whose children have not been resolved yet.
      case e if !e.childrenResolved => e
 
-      case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
+      case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
        findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
          val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
          val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
@@ -478,7 +479,7 @@ object HiveTypeCoercion {
 
    // Promote integers inside a binary expression with fixed-precision decimals to decimals,
    // and fixed-precision decimals in an expression with floats / doubles to doubles
-    case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
+    case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
      (left.dataType, right.dataType) match {
        case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
          b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
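
The widening hunks above all follow one pattern: when a binary operator's children disagree on type, wrap the narrower child in a Cast to the tightest common type. A hedged toy model in Scala (DType, Expr, CastExpr, and widestOf are simplified stand-ins, not the Catalyst API):

sealed trait DType
case object IntT extends DType
case object DoubleT extends DType

case class Expr(name: String, dataType: DType)
// Stand-in for Catalyst's Cast: the wrapped child takes the target type.
case class CastExpr(child: Expr, dataType: DType)

// Stand-in for findTightestCommonTypeOfTwo: equal types are already tight;
// otherwise widen Int to Double; anything else has no common type.
def widestOf(a: DType, b: DType): Option[DType] = (a, b) match {
  case (x, y) if x == y                  => Some(x)
  case (IntT, DoubleT) | (DoubleT, IntT) => Some(DoubleT)
  case _                                 => None
}

// Mirrors the rule body: cast only the side that is not already widest.
def coerce(left: Expr, right: Expr): Option[(Any, Any)] =
  widestOf(left.dataType, right.dataType).map { widest =>
    val newLeft  = if (left.dataType == widest) left else CastExpr(left, widest)
    val newRight = if (right.dataType == widest) right else CastExpr(right, widest)
    (newLeft, newRight)
  }

For example, coerce(Expr("age", IntT), Expr("salary", DoubleT)) casts only the left child to DoubleT and leaves the right child as-is.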
New file in package org.apache.spark.sql.catalyst.expressions, defining ExpectsInputTypes and AutoCastInputTypes

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.types.DataType
+
+
+/**
+ * A trait that is mixed in to define the expected input types of an expression.
+ */
+trait ExpectsInputTypes { self: Expression =>
+
+  /**
+   * Expected input types from child expressions. The i-th position in the returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   * 1. a specific data type, e.g. LongType, StringType.
+   * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
+   * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+   */
+  def inputTypes: Seq[Any]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // We will do the type checking in `HiveTypeCoercion`, so we always return success here.
+    TypeCheckResult.TypeCheckSuccess
+  }
+}
+
+/**
+ * Expressions that require a specific `DataType` as input should implement this trait
+ * so that the proper type conversions can be performed in the analyzer.
+ */
+trait AutoCastInputTypes { self: Expression =>
+
+  def inputTypes: Seq[DataType]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
+    // so a type mismatch error won't be reported here, but for the underlying `Cast`s.
+    TypeCheckResult.TypeCheckSuccess
+  }
+}
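
As a usage note, the doc comment on inputTypes above allows three kinds of entries per child position. A hedged sketch with stubbed types (the stubs and ToyHash are illustrative only, not part of the patch):

sealed trait DataType // stub for org.apache.spark.sql.types.DataType
case object StringType extends DataType
case object BinaryType extends DataType

trait Expression // stub for the Catalyst Expression base class

trait ExpectsInputTypes { self: Expression =>
  // Position i constrains child i: a concrete type (case 1), an abstract
  // type family (case 2), or a Seq of acceptable types (case 3).
  def inputTypes: Seq[Any]
}

// Hypothetical hash-like expression accepting a string or binary child,
// in the spirit of the Sha expressions the commit message mentions.
case class ToyHash(child: Expression) extends Expression with ExpectsInputTypes {
  override def inputTypes: Seq[Any] = Seq(Seq(StringType, BinaryType))
}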
