Skip to content

Commit 075ce49

Browse files
cloud-fan authored and marmbrus committed
[SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)
A simpler version of #9279, only support 2 datasets. Author: Wenchen Fan <[email protected]> Closes #9324 from cloud-fan/cogroup2.
1 parent 5f1cee6 commit 075ce49

File tree

8 files changed

+257
-0
lines changed

8 files changed

+257
-0
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ public String toString() {
591591
build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
592592
build.append(',');
593593
}
594+
build.deleteCharAt(build.length() - 1);
594595
build.append(']');
595596
return build.toString();
596597
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,42 @@ case class MapGroups[K, T, U](
513513
override def missingInput: AttributeSet = AttributeSet.empty
514514
}
515515

516+
/**
 * Companion factory for logical `CoGroup` nodes.
 *
 * Callers only supply the function, the grouping attributes and the two child plans; the
 * encoders for the key, both input element types and the result type are looked up from the
 * context bounds, and the output attributes are taken from the result encoder's schema.
 */
object CoGroup {
  def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
      func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
      leftGroup: Seq[Attribute],
      rightGroup: Seq[Attribute],
      left: LogicalPlan,
      right: LogicalPlan): CoGroup[K, Left, Right, R] = {
    // Resolve every encoder once, in the same order the constructor expects them.
    val keyEncoder = encoderFor[K]
    val leftEncoder = encoderFor[Left]
    val rightEncoder = encoderFor[Right]
    val resultEncoder = encoderFor[R]

    CoGroup(
      func,
      keyEncoder,
      leftEncoder,
      rightEncoder,
      resultEncoder,
      // The node's output is exactly the flattened schema of the result type.
      resultEncoder.schema.toAttributes,
      leftGroup,
      rightGroup,
      left,
      right)
  }
}
537+
538+
/**
 * Logical operator that co-groups the rows of its `left` and `right` children and applies `func`
 * to each grouping key together with the values associated with that key on either side.
 *
 * @param func       function invoked per grouping key with the left and right values.
 * @param kEncoder   encoder for the grouping key type `K`.
 * @param leftEnc    encoder for the element type of the left child.
 * @param rightEnc   encoder for the element type of the right child.
 * @param rEncoder   encoder for the result type `R`.
 * @param output     attributes produced by this operator.
 * @param leftGroup  grouping attributes of the left child.
 * @param rightGroup grouping attributes of the right child.
 * @param left       left child plan.
 * @param right      right child plan.
 */
case class CoGroup[K, Left, Right, R](
    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
    kEncoder: ExpressionEncoder[K],
    leftEnc: ExpressionEncoder[Left],
    rightEnc: ExpressionEncoder[Right],
    rEncoder: ExpressionEncoder[R],
    output: Seq[Attribute],
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    left: LogicalPlan,
    right: LogicalPlan) extends BinaryNode {

  // This operator never reports any attribute as missing from its children.
  override def missingInput: AttributeSet = AttributeSet.empty
}

sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
6565
sqlContext,
6666
MapGroups(f, groupingAttributes, logicalPlan))
6767
}
68+
69+
/**
 * Co-groups this grouped dataset with `other` and applies `f` to every group. For each unique
 * grouping key, `f` receives the key plus two iterators: the elements of that group from `this`
 * [[Dataset]] and the elements of that group from `other`. Whatever `f` returns, of an arbitrary
 * type `R`, becomes the content of the resulting [[Dataset]].
 */
def cogroup[U, R : Encoder](
    other: GroupedDataset[K, U])(
    f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
  // The CoGroup factory needs an implicit Encoder[U]; reuse the one carried by `other`.
  implicit def uEnc: Encoder[U] = other.tEncoder

  val cogroupPlan = CoGroup(
    f,
    this.groupingAttributes,
    other.groupingAttributes,
    this.logicalPlan,
    other.logicalPlan)

  new Dataset[R](sqlContext, cogroupPlan)
}
6888
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
22+
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
23+
24+
/**
 * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
 * grouping key with its associated values from all [[GroupedIterator]]s.
 * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
 *
 * @param left           groups of the left side as (grouping key, rows of that group) pairs.
 * @param right          groups of the right side, in the same shape as `left`.
 * @param groupingSchema attributes of the grouping key, used to generate the ordering that
 *                       decides which side's key comes first when the keys differ.
 */
class CoGroupedIterator(
    left: Iterator[(InternalRow, Iterator[InternalRow])],
    right: Iterator[(InternalRow, Iterator[InternalRow])],
    groupingSchema: Seq[Attribute])
  extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {

  private val keyOrdering =
    GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)

  // One-group look-ahead per side. `null` means nothing is buffered for that side; a non-null
  // value is a group that was pulled from the input but not yet handed out by `next()`.
  private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
  private var currentRightData: (InternalRow, Iterator[InternalRow]) = _

  /**
   * True while any group remains to be returned. A group may still be sitting in one of the
   * look-ahead slots after both inputs are exhausted (e.g. left keys 1, 3 and right key 2), so
   * the buffered slots must be checked as well as the inputs; otherwise that group is dropped.
   */
  override def hasNext: Boolean =
    currentLeftData.ne(null) || currentRightData.ne(null) || left.hasNext || right.hasNext

  /**
   * Returns the next grouping key with the rows of that group from the left and right side.
   * A side that has no rows for the key contributes an empty iterator.
   *
   * @throws NoSuchElementException if there is no group left on either side.
   */
  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    // Refill whichever look-ahead slot is empty, if its input still has groups.
    if (currentLeftData.eq(null) && left.hasNext) {
      currentLeftData = left.next()
    }
    if (currentRightData.eq(null) && right.hasNext) {
      currentRightData = right.next()
    }

    if (currentLeftData.eq(null) && currentRightData.eq(null)) {
      throw new NoSuchElementException("next on empty CoGroupedIterator")
    }

    if (currentLeftData.eq(null)) {
      // left is null, right is not null, consume the right data.
      rightOnly()
    } else if (currentRightData.eq(null)) {
      // left is not null, right is null, consume the left data.
      leftOnly()
    } else if (currentLeftData._1 == currentRightData._1) {
      // left and right have the same grouping key, consume both of them.
      val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
      currentLeftData = null
      currentRightData = null
      result
    } else {
      val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
      assert(compare != 0)
      if (compare < 0) {
        // the grouping key of left is smaller, consume the left data.
        leftOnly()
      } else {
        // the grouping key of right is smaller, consume the right data.
        rightOnly()
      }
    }
  }

  /** Emits the buffered left group with an empty right side and clears the left slot. */
  private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
    currentLeftData = null
    result
  }

  /** Emits the buffered right group with an empty left side and clears the right slot. */
  private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    val result = (currentRightData._1, Iterator.empty, currentRightData._2)
    currentRightData = null
    result
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
393393
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
394394
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
395395
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
396+
case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
397+
leftGroup, rightGroup, left, right) =>
398+
execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
399+
planLater(left), planLater(right)) :: Nil
396400

397401
case logical.Repartition(numPartitions, shuffle, child) =>
398402
if (shuffle) {

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
390390
}
391391
}
392392
}
393+
394+
/**
 * Physical operator that co-groups the rows of its two children. For every grouping key it calls
 * `func` with the decoded key and two iterators holding the decoded elements of that group from
 * the left and the right side; the objects returned by `func` are encoded back to rows and
 * flattened into the output.
 */
case class CoGroup[K, Left, Right, R](
    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
    kEncoder: ExpressionEncoder[K],
    leftEnc: ExpressionEncoder[Left],
    rightEnc: ExpressionEncoder[Right],
    rEncoder: ExpressionEncoder[R],
    output: Seq[Attribute],
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute],
    left: SparkPlan,
    right: SparkPlan) extends BinaryNode {

  // Each child must be clustered by its own grouping attributes.
  override def requiredChildDistribution: Seq[Distribution] =
    Seq(ClusteredDistribution(leftGroup), ClusteredDistribution(rightGroup))

  // Each child must be sorted ascending on its grouping attributes, which is what
  // CoGroupedIterator assumes of its inputs.
  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    Seq(leftGroup, rightGroup).map(group => group.map(SortOrder(_, Ascending)))

  override protected def doExecute(): RDD[InternalRow] = {
    left.execute().zipPartitions(right.execute()) { (leftRows, rightRows) =>
      val leftGroups = GroupedIterator(leftRows, leftGroup, left.output)
      val rightGroups = GroupedIterator(rightRows, rightGroup, right.output)
      val boundKeyEncoder = kEncoder.bind(leftGroup)

      // Decode one co-group, run the user function on it and encode whatever it returns.
      def processGroup(
          key: InternalRow,
          leftValues: Iterator[InternalRow],
          rightValues: Iterator[InternalRow]): Iterator[InternalRow] = {
        val produced = func(
          boundKeyEncoder.fromRow(key),
          leftValues.map(leftEnc.fromRow),
          rightValues.map(rightEnc.fromRow))
        produced.map(rEncoder.toRow)
      }

      new CoGroupedIterator(leftGroups, rightGroups, leftGroup).flatMap {
        case (key, leftValues, rightValues) => processGroup(key, leftValues, rightValues)
      }
    }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
202202
agged,
203203
("a", 30), ("b", 3), ("c", 1))
204204
}
205+
206+
// Co-groups two keyed datasets and checks keys present on only one side as well as on both.
test("cogroup") {
  val leftDs = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
  val rightDs = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()

  val cogrouped = leftDs.groupBy(_._1).cogroup(rightDs.groupBy(_._1)) {
    (key, leftValues, rightValues) =>
      // Concatenate the strings of each side, separated by '#'.
      val leftText = leftValues.map(_._2).mkString
      val rightText = rightValues.map(_._2).mkString
      Iterator((key, leftText + "#" + rightText))
  }

  checkAnswer(
    cogrouped,
    (1, "a#"), (2, "#q"), (3, "abcfoo#w"), (5, "hello#er"))
}
205217
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
23+
24+
/** Unit tests for [[CoGroupedIterator]] driven by hand-built rows. */
class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    // Left side has keys 1 (two rows) and 2; right side has keys 1, 2 and 3.
    val leftRows = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c"))
    val rightRows = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L))

    val leftGrouped =
      GroupedIterator(leftRows.iterator, Seq('i.int.at(0)), Seq('i.int, 's.string))
    val rightGrouped =
      GroupedIterator(rightRows.iterator, Seq('i.int.at(0)), Seq('i.int, 'l.long))
    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))

    val actual = cogrouped.map {
      case (key, leftData, rightData) =>
        assert(key.numFields == 1)
        (key.getInt(0), leftData.toSeq, rightData.toSeq)
    }.toSeq

    // Key 3 exists only on the right, so its left side is empty.
    val expected = Seq(
      (1, Seq(create_row(1, "a"), create_row(1, "b")), Seq(create_row(1, 2L))),
      (2, Seq(create_row(2, "c")), Seq(create_row(2, 3L))),
      (3, Seq.empty, Seq(create_row(3, 4L))))

    assert(actual == expected)
  }
}

0 commit comments

Comments
 (0)