Skip to content

Commit ed0d1e1

Browse files
xuanyuankingJackey Lee
authored and committed
[SPARK-25949][SQL] Add test for PullOutPythonUDFInJoinCondition
## What changes were proposed in this pull request? As commented in apache#22326 (comment), the newly added optimizer rule is currently tested only by end-to-end tests on the Python side; we need to add a suite under `org.apache.spark.sql.catalyst.optimizer`, like other optimizer rules. ## How was this patch tested? Newly added UT. Closes apache#22955 from xuanyuanking/SPARK-25949. Authored-by: Yuanjian Li <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 6fc2151 commit ed0d1e1

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.scalatest.Matchers._
21+
22+
import org.apache.spark.api.python.PythonEvalType
23+
import org.apache.spark.sql.AnalysisException
24+
import org.apache.spark.sql.catalyst.dsl.expressions._
25+
import org.apache.spark.sql.catalyst.dsl.plans._
26+
import org.apache.spark.sql.catalyst.expressions.PythonUDF
27+
import org.apache.spark.sql.catalyst.plans._
28+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
29+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
30+
import org.apache.spark.sql.internal.SQLConf._
31+
import org.apache.spark.sql.types.BooleanType
32+
33+
/**
 * Plan-level tests for the `PullOutPythonUDFInJoinCondition` optimizer rule.
 *
 * Every test joins two local relations using a condition that contains a dummy Python UDF,
 * runs the rule followed by `CheckCartesianProducts`, and compares the outcome with a
 * hand-built expected plan in which the UDF predicate sits in a filter above the join.
 */
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {

  /** Applies the rule under test and then the cartesian product check, once each. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Extract PythonUDF From JoinCondition", Once,
        PullOutPythonUDFInJoinCondition) ::
      Batch("Check Cartesian Products", Once,
        CheckCartesianProducts) :: Nil
  }

  val testRelationLeft = LocalRelation('a.int, 'b.int)
  val testRelationRight = LocalRelation('c.int, 'd.int)

  /**
   * Builds a deterministic, argument-less, boolean batched Python UDF expression.
   *
   * The function payload is `null`, so the expression can only be used for planning and
   * optimization in these tests; it cannot be executed.
   *
   * @param name name given to the UDF expression
   */
  private def dummyPythonUDF(name: String): PythonUDF =
    PythonUDF(name, null,
      BooleanType,
      Seq.empty,
      PythonEvalType.SQL_BATCHED_UDF,
      udfDeterministic = true)

  // Dummy python UDF for testing. Unable to execute.
  val pythonUDF = dummyPythonUDF("pythonUDF")

  // Join types for which the tests expect an AnalysisException when the condition has a UDF.
  val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti)

  /**
   * Checks a query whose optimized form contains a join without a join condition.
   *
   * With cross joins disabled the optimization must fail with the "implicit cartesian product"
   * error; with cross joins enabled the optimized plan must equal `expected`.
   *
   * @param query    unanalyzed plan to optimize
   * @param expected analyzed plan the optimizer is expected to produce
   */
  private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = {
    // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false.
    // The flag is pinned explicitly so this check does not depend on the session default.
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      val exception = intercept[AnalysisException] {
        Optimize.execute(query.analyze)
      }
      assert(exception.message.startsWith("Detected implicit cartesian product"))
    }

    // pull out the python udf while set spark.sql.crossJoin.enabled=true
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      val optimized = Optimize.execute(query.analyze)
      comparePlans(optimized, expected)
    }
  }

  test("inner join condition with python udf only") {
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(pythonUDF))
    // The whole condition moves into a filter, leaving an unconditioned inner join.
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(pythonUDF).analyze
    comparePlanWithCrossJoinEnable(query, expected)
  }

  test("left semi join condition with python udf only") {
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = LeftSemi,
      condition = Some(pythonUDF))
    // Expected shape: inner join, filter on the UDF, then a projection of the left columns.
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(pythonUDF).select('a, 'b).analyze
    comparePlanWithCrossJoinEnable(query, expected)
  }

  test("python udf and common condition") {
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(pythonUDF && 'a.attr === 'c.attr))
    // The equality stays in the join, so no cartesian product check is involved here.
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected)
  }

  test("python udf or common condition") {
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(pythonUDF || 'a.attr === 'c.attr))
    // A disjunction containing the UDF is expected to move to the filter as a whole.
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze
    comparePlanWithCrossJoinEnable(query, expected)
  }

  test("pull out whole complex condition with multiple python udf") {
    val pythonUDF1 = dummyPythonUDF("pythonUDF1")
    val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1

    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(condition))
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(condition).analyze
    comparePlanWithCrossJoinEnable(query, expected)
  }

  test("partial pull out complex condition with multiple python udf") {
    val pythonUDF1 = dummyPythonUDF("pythonUDF1")
    val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr

    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(condition))
    // Only the UDF-only conjunct is expected in the filter; the equality remains in the join.
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected)
  }

  test("throw an exception for not support join type") {
    unsupportedJoinTypes.foreach { joinType =>
      val thrownException = the [AnalysisException] thrownBy {
        val query = testRelationLeft.join(
          testRelationRight,
          joinType,
          condition = Some(pythonUDF))
        Optimize.execute(query.analyze)
      }
      // `==` inside assert lets ScalaTest print both strings when the messages differ.
      assert(thrownException.message ==
        s"Using PythonUDF in join condition of join type $joinType is not supported.")
    }
  }
}
171+

0 commit comments

Comments
 (0)