From 344b516a414107cf5494951ae74886b6f62f3e5a Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 6 Nov 2018 20:29:26 +0800 Subject: [PATCH 1/3] Add test for PullOutPythonUDFInJoinCondition --- ...PullOutPythonUDFInJoinConditionSuite.scala | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala new file mode 100644 index 000000000000..d494a07e4391 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.scalatest.Matchers._ + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.types.BooleanType + +class PullOutPythonUDFInJoinConditionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :: + Batch("Check Cartesian Products", Once, + CheckCartesianProducts) :: Nil + } + + val testRelationLeft = LocalRelation('a.int, 'b.int) + val testRelationRight = LocalRelation('c.int, 'd.int) + + // Dummy python UDF for testing. Unable to execute. + val pythonUDF = PythonUDF("pythonUDF", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + + val notSupportJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) + + test("inner join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).analyze + + // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false + val exception = the [AnalysisException] thrownBy { + Optimize.execute(query.analyze) + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + // pull out the python udf while set spark.sql.crossJoin.enabled=true + withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + } + + test("left semi join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).select('a, 'b).analyze + + // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false + val exception = the [AnalysisException] thrownBy { + Optimize.execute(query.analyze) + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + // pull out the python udf while set spark.sql.crossJoin.enabled=true + withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + } + + test("python udf with other common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF && 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("throw an exception for not support join type") { + for (joinType <- notSupportJoinTypes) { + val thrownException = the [AnalysisException] thrownBy { + val query = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(pythonUDF)) + Optimize.execute(query.analyze) + } + assert(thrownException.message.contentEquals( + s"Using PythonUDF in join condition of join type $joinType is not supported.")) + } + } +} + From 38b15552995355d5e00186fb2b332928a83d248a Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 9 Nov 2018 15:42:39 +0800 Subject: [PATCH 2/3] More cases for complex condition and comment address --- ...PullOutPythonUDFInJoinConditionSuite.scala | 95 ++++++++++++++----- 1 file changed, 69 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala index d494a07e4391..f57c1ef6d4dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -50,20 +50,11 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val notSupportJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) - - test("inner join condition with python udf only") { - val query = testRelationLeft.join( - testRelationRight, - joinType = Inner, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( - testRelationRight, - joinType = Inner, - condition = None).where(pythonUDF).analyze + val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) + private def comparePlansWithConf(query: LogicalPlan, expected: LogicalPlan): Unit = { // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false - val exception = the [AnalysisException] thrownBy { + val exception = intercept[AnalysisException] { Optimize.execute(query.analyze) } assert(exception.message.startsWith("Detected implicit cartesian product")) @@ -75,6 +66,18 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { } } + test("inner join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).analyze + comparePlansWithConf(query, expected) + } + test("left semi join condition with python udf only") { val query = testRelationLeft.join( testRelationRight, @@ -84,21 +87,10 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { testRelationRight, joinType = Inner, condition = None).where(pythonUDF).select('a, 'b).analyze - - // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false - val exception = the [AnalysisException] thrownBy { - Optimize.execute(query.analyze) - } - assert(exception.message.startsWith("Detected implicit cartesian product")) - - // pull out the python udf while set spark.sql.crossJoin.enabled=true - withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { - val optimized = Optimize.execute(query.analyze) - comparePlans(optimized, expected) - } + comparePlansWithConf(query, expected) } - test("python udf with other common condition") { + test("python udf and common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, @@ -111,8 +103,59 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { comparePlans(optimized, expected) } + test("python udf or common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF || 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + comparePlansWithConf(query, expected) + } + + test("pull out whole complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(condition).analyze + comparePlansWithConf(query, expected) + } + + test("partial pull out complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + test("throw an exception for not support join type") { - for (joinType <- notSupportJoinTypes) { + for (joinType <- unsupportedJoinTypes) { val thrownException = the [AnalysisException] thrownBy { val query = testRelationLeft.join( testRelationRight, From 8d04b4c4f084610e8ce8f11590ad4cd537c5952f Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sun, 11 Nov 2018 22:00:12 +0800 Subject: [PATCH 3/3] better naming --- .../PullOutPythonUDFInJoinConditionSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala index f57c1ef6d4dd..d3867f2b6bd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -52,7 +52,7 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) - private def comparePlansWithConf(query: LogicalPlan, expected: LogicalPlan): Unit = { + private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = { // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false val exception = intercept[AnalysisException] { Optimize.execute(query.analyze) @@ -75,7 +75,7 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { testRelationRight, joinType = Inner, condition = None).where(pythonUDF).analyze - comparePlansWithConf(query, expected) + comparePlanWithCrossJoinEnable(query, expected) } test("left semi join condition with python udf only") { @@ -87,7 +87,7 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { testRelationRight, joinType = Inner, condition = None).where(pythonUDF).select('a, 'b).analyze - comparePlansWithConf(query, expected) + comparePlanWithCrossJoinEnable(query, expected) } test("python udf and common condition") { @@ -112,7 +112,7 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { testRelationRight, joinType = Inner, condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze - comparePlansWithConf(query, expected) + comparePlanWithCrossJoinEnable(query, expected) } test("pull out whole complex condition with multiple python udf") { @@ -131,7 +131,7 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { testRelationRight, joinType = Inner, condition = None).where(condition).analyze - comparePlansWithConf(query, expected) + comparePlanWithCrossJoinEnable(query, expected) } test("partial pull out complex condition with multiple python udf") {