Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,19 +548,22 @@ case class MapGroups[K, T, U](

/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
def apply[Key, Left, Right, Result : Encoder](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup[K, Left, Right, R] = {
right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
CoGroup(
func,
encoderFor[K],
encoderFor[Left],
encoderFor[Right],
encoderFor[R],
encoderFor[R].schema.toAttributes,
keyEnc,
leftEnc,
rightEnc,
encoderFor[Result],
encoderFor[Result].schema.toAttributes,
leftGroup,
rightGroup,
left,
Expand All @@ -572,12 +575,12 @@ object CoGroup {
* A relation produced by applying `func` to each grouping key and associated values from left and
* right children.
*/
case class CoGroup[K, Left, Right, R](
func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
kEncoder: ExpressionEncoder[K],
case class CoGroup[Key, Left, Right, Result](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
rEncoder: ExpressionEncoder[R],
resultEnc: ExpressionEncoder[Result],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql](
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.unresolvedVEncoder
new Dataset[R](
sqlContext,
CoGroup(
f,
this.resolvedKEncoder,
this.resolvedVEncoder,
other.resolvedVEncoder,
this.groupingAttributes,
other.groupingAttributes,
this.logicalPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,12 @@ case class MapGroups[K, T, U](
* iterators containing all elements in the group from left and right side.
* The result of this function is encoded and flattened before being output.
*/
case class CoGroup[K, Left, Right, R](
func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
kEncoder: ExpressionEncoder[K],
case class CoGroup[Key, Left, Right, Result](
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
rEncoder: ExpressionEncoder[R],
resultEnc: ExpressionEncoder[Result],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
Expand All @@ -392,15 +392,17 @@ case class CoGroup[K, Left, Right, R](
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
val groupKeyEncoder = kEncoder.bind(leftGroup)
val boundKeyEnc = keyEnc.bind(leftGroup)
val boundLeftEnc = leftEnc.bind(left.output)
val boundRightEnc = rightEnc.bind(right.output)

new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
val result = func(
groupKeyEncoder.fromRow(key),
leftResult.map(leftEnc.fromRow),
rightResult.map(rightEnc.fromRow))
result.map(rEncoder.toRow)
boundKeyEnc.fromRow(key),
leftResult.map(boundLeftEnc.fromRow),
rightResult.map(boundRightEnc.fromRow))
result.map(resultEnc.toRow)
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}

test("cogroup with complex data") {
val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()
val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
}

checkAnswer(
cogrouped,
1 -> "a", 2 -> "bc", 3 -> "d")
}

test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")
Expand Down