Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ trait PredicateHelper {
*
* For example consider a join between two relations R(a, b) and S(c, d).
*
* `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
* `false`.
* - `canEvaluate(EqualTo(a,b), R)` returns `true`
* - `canEvaluate(EqualTo(a,c), R)` returns `false`
* - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
*/
// Returns true iff every attribute referenced by `expr` is produced by `plan`'s
// output, i.e. the predicate can be evaluated using only `plan`'s columns.
// An expression with no references (e.g. a Literal) has an empty reference set,
// which is trivially a subset — so literals can be "evaluated" on any plan.
protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
expr.references.subsetOf(plan.outputSet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
// find out the first join that have at least one join condition
val conditionalJoin = rest.find { plan =>
val refs = left.outputSet ++ plan.outputSet
conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
conditions
.filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
.filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
.exists(_.references.subsetOf(refs))
}
// pick the next one if no condition left
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
// as join keys.
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
val joinKeys = predicates.flatMap {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
// Replace null with default value for joining key, then those rows with null in it could
Expand All @@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
case other => None
}
val otherPredicates = predicates.filterNot {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case EqualTo(l, r) =>
canEvaluate(l, left) && canEvaluate(r, right) ||
canEvaluate(l, right) && canEvaluate(r, left)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
private def testBucketing(
bucketSpecLeft: Option[BucketSpec],
bucketSpecRight: Option[BucketSpec],
joinColumns: Seq[String],
joinType: String = "inner",
joinCondition: (DataFrame, DataFrame) => Column,
shuffleLeft: Boolean,
shuffleRight: Boolean): Unit = {
withTable("bucketed_table1", "bucketed_table2") {
Expand All @@ -256,12 +257,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val t1 = spark.table("bucketed_table1")
val t2 = spark.table("bucketed_table2")
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
val joined = t1.join(t2, joinCondition(t1, t2), joinType)

// First, check that the result is correct.
checkAnswer(
joined.sort("bucketed_table1.k", "bucketed_table2.k"),
df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))

assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
Expand All @@ -276,47 +277,89 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}

/**
 * Builds an equi-join condition over `joinCols`, curried so that a partially
 * applied `joinCondition(cols)` can be handed to `testBucketing` as a
 * `(DataFrame, DataFrame) => Column` function.
 *
 * @param joinCols non-empty list of column names present on both sides
 */
private def joinCondition(joinCols: Seq[String])(left: DataFrame, right: DataFrame): Column = {
  require(joinCols.nonEmpty, "joinCols must not be empty")
  // AND together one equality predicate per join column.
  joinCols.map(col => left(col) === right(col)).reduce(_ && _)
}

test("avoid shuffle when join 2 bucketed tables") {
  // Both tables are bucketed into 8 buckets on exactly the join keys (i, j),
  // so the sort-merge join should read both sides without an exchange.
  val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec,
    bucketSpecRight = bucketSpec,
    joinCondition = joinCondition(Seq("i", "j")),
    shuffleLeft = false,
    shuffleRight = false
  )
}

// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
  // Bucketed only on "i" but joined on (i, j). Ignored until SPARK-12704 lets
  // the planner exploit bucketing when join keys are a super-set of bucket keys.
  val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec,
    bucketSpecRight = bucketSpec,
    joinCondition = joinCondition(Seq("i", "j")),
    shuffleLeft = false,
    shuffleRight = false
  )
}

test("only shuffle one side when join bucketed table and non-bucketed table") {
  // Left side is bucketed on the join keys; the unbucketed right side must be
  // shuffled to match, but the left side should be read in place.
  val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec,
    bucketSpecRight = None,
    joinCondition = joinCondition(Seq("i", "j")),
    shuffleLeft = false,
    shuffleRight = true
  )
}

test("only shuffle one side when 2 bucketed tables have different bucket number") {
  // Same bucket keys but different bucket counts (8 vs 5): one side must be
  // re-shuffled to the other's partitioning; the other is read in place.
  val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
  val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec1,
    bucketSpecRight = bucketSpec2,
    joinCondition = joinCondition(Seq("i", "j")),
    shuffleLeft = false,
    shuffleRight = true
  )
}

test("only shuffle one side when 2 bucketed tables have different bucket keys") {
  // Join is on "i": the left table is bucketed on "i" (usable as-is) while the
  // right is bucketed on "j" and therefore has to be shuffled.
  val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
  val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec1,
    bucketSpecRight = bucketSpec2,
    joinCondition = joinCondition(Seq("i")),
    shuffleLeft = false,
    shuffleRight = true
  )
}

test("shuffle when join keys are not equal to bucket keys") {
  // Both sides are bucketed on "i" but the join is on "j", so the bucketing
  // cannot be used and both sides must be shuffled.
  val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
  testBucketing(
    bucketSpecLeft = bucketSpec,
    bucketSpecRight = bucketSpec,
    joinCondition = joinCondition(Seq("j")),
    shuffleLeft = true,
    shuffleRight = true
  )
}

test("shuffle when join 2 bucketed tables with bucketing disabled") {
  val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
  // With bucketing disabled the planner must ignore the bucket metadata and
  // shuffle both sides, even though the layout would otherwise avoid it.
  withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
    testBucketing(
      bucketSpecLeft = bucketSpec,
      bucketSpecRight = bucketSpec,
      joinCondition = joinCondition(Seq("i", "j")),
      shuffleLeft = true,
      shuffleRight = true
    )
  }
}

Expand Down Expand Up @@ -348,6 +391,23 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}

test("SPARK-17698 Join predicates should not contain filter clauses") {
  val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
  // The condition mixes one genuine equi-join predicate with two single-sided
  // filter predicates. Only the former may be treated as a join key; both
  // bucketed sides should still be read without a shuffle.
  val condition = (left: DataFrame, right: DataFrame) => {
    val equiJoinKey = left("i") === right("i")
    val leftOnlyFilter = left("i") === Literal("1")
    val rightOnlyFilter = right("i") === Literal("1")
    equiJoinKey && leftOnlyFilter && rightOnlyFilter
  }
  testBucketing(
    bucketSpecLeft = bucketSpec,
    bucketSpecRight = bucketSpec,
    joinType = "fullouter",
    joinCondition = condition,
    shuffleLeft = false,
    shuffleRight = false
  )
}

test("error if there exists any malformed bucket files") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
Expand Down