
Commit f74ae0e

aarondav authored and rxin committed
SPARK-1098: Minor cleanup of ClassTag usage in Java API

Our usage of fake ClassTags in this manner is probably not healthy, but I'm not sure if there's a better solution available, so I just cleaned up and documented the current one.

Author: Aaron Davidson <[email protected]>

Closes #604 from aarondav/master and squashes the following commits:

b398e89 [Aaron Davidson] SPARK-1098: Minor cleanup of ClassTag usage in Java API
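For context: the cleanup replaces the repeated implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] incantation with a single documented helper, JavaSparkContext.fakeClassTag, imported throughout the diff below. The helper itself lives in JavaSparkContext.scala, presumably the fourth changed file, whose hunk is not captured on this page. A minimal sketch of what the helper amounts to, inferred from the expressions it replaces (the exact doc comment in the Spark source may differ):

import scala.reflect.ClassTag

object JavaSparkContext {
  /**
   * Produce a ClassTag[T] for any T by reusing the tag for AnyRef.
   * The tag lies about T's runtime class, so any array it creates is
   * an Array[Object] rather than a real Array[T].
   */
  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}

This is what the commit message means by "probably not healthy": fakeClassTag[String].runtimeClass is java.lang.Object, not java.lang.String. It is tolerable in the Java API because Java generics are erased anyway, so Java callers only ever handle the resulting arrays through Object references.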
1 parent: e0d49ad

File tree: 4 files changed, +108 -100 lines


core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 40 additions & 42 deletions

@@ -17,35 +17,29 @@
 
 package org.apache.spark.api.java
 
-import java.util.{List => JList}
-import java.util.Comparator
+import java.util.{Comparator, List => JList}
 
-import scala.Tuple2
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
 
 import com.google.common.base.Optional
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.HashPartitioner
-import org.apache.spark.Partitioner
+import org.apache.spark.{HashPartitioner, Partitioner}
 import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext.rddToPairRDDFunctions
-import org.apache.spark.api.java.function.{Function2 => JFunction2}
-import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.partial.BoundedDouble
-import org.apache.spark.partial.PartialResult
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.OrderedRDDFunctions
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
 import org.apache.spark.storage.StorageLevel
 
-
-class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K],
-    implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
+                       (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V])
+  extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
 
   override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
 
@@ -158,7 +152,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
       mergeValue: JFunction2[C, V, C],
       mergeCombiners: JFunction2[C, C, C],
       partitioner: Partitioner): JavaPairRDD[K, C] = {
-    implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
+    implicit val ctag: ClassTag[C] = fakeClassTag
     fromRDD(rdd.combineByKey(
       createCombiner,
       mergeValue,
@@ -284,19 +278,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
    * RDD will be <= us.
    */
   def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = {
-    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    implicit val ctag: ClassTag[W] = fakeClassTag
     fromRDD(rdd.subtractByKey(other))
   }
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
   def subtractByKey[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, V] = {
-    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    implicit val ctag: ClassTag[W] = fakeClassTag
     fromRDD(rdd.subtractByKey(other, numPartitions))
   }
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
   def subtractByKey[W](other: JavaPairRDD[K, W], p: Partitioner): JavaPairRDD[K, V] = {
-    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    implicit val ctag: ClassTag[W] = fakeClassTag
     fromRDD(rdd.subtractByKey(other, p))
   }
 
@@ -345,7 +339,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
   def combineByKey[C](createCombiner: JFunction[V, C],
       mergeValue: JFunction2[C, V, C],
       mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
-    implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
+    implicit val ctag: ClassTag[C] = fakeClassTag
     fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
   }

@@ -438,7 +432,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
    * this also retains the original RDD's partitioning.
    */
   def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
-    implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+    implicit val ctag: ClassTag[U] = fakeClassTag
     fromRDD(rdd.mapValues(f))
   }

@@ -449,7 +443,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
   def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
     import scala.collection.JavaConverters._
     def fn = (x: V) => f.apply(x).asScala
-    implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+    implicit val ctag: ClassTag[U] = fakeClassTag
     fromRDD(rdd.flatMapValues(fn))
   }

@@ -682,31 +676,35 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
 }
 
 object JavaPairRDD {
-  def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K],
-    vcm: ClassTag[T]): RDD[(K, JList[T])] =
-    rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
-
-  def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K],
-    vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd)
-    .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
-
-  def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
-    Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1],
-    JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
-    (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
-      seqAsJavaList(x._2),
-      seqAsJavaList(x._3)))
-
-  def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+  private[spark]
+  def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Seq[T])]): RDD[(K, JList[T])] = {
+    rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList)
+  }
+
+  private[spark]
+  def cogroupResultToJava[K: ClassTag, V, W](
+      rdd: RDD[(K, (Seq[V], Seq[W]))]): RDD[(K, (JList[V], JList[W]))] = {
+    rddToPairRDDFunctions(rdd).mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+  }
+
+  private[spark]
+  def cogroupResult2ToJava[K: ClassTag, V, W1, W2](
+      rdd: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))]): RDD[(K, (JList[V], JList[W1], JList[W2]))] = {
+    rddToPairRDDFunctions(rdd)
+      .mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2), seqAsJavaList(x._3)))
+  }
+
+  def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
     new JavaPairRDD[K, V](rdd)
+  }
 
   implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
 
 
   /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
   def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
-    implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
-    implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+    implicit val ctagK: ClassTag[K] = fakeClassTag
+    implicit val ctagV: ClassTag[V] = fakeClassTag
     new JavaPairRDD[K, V](rdd.rdd)
   }
 
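One detail worth calling out in the hunk above: the helpers switch from hand-written implicit parameters (including a vcm tag the old bodies never used) to context-bound syntax, and are narrowed to private[spark]. The two spellings are equivalent; a minimal sketch of the desugaring, with illustrative names not taken from the diff:

import scala.reflect.ClassTag

object ContextBoundSketch {
  // Context-bound form, as in the new groupByResultToJava signature.
  def tagOfBound[K: ClassTag]: ClassTag[K] = implicitly[ClassTag[K]]

  // What the compiler desugars it to, as in the old signatures.
  def tagOfExplicit[K](implicit kcm: ClassTag[K]): ClassTag[K] = kcm

  def main(args: Array[String]): Unit = {
    println(tagOfBound[String])     // java.lang.String
    println(tagOfExplicit[String])  // java.lang.String
  }
}

Dropping the unused value-side tag is also why the call sites in JavaRDDLike below no longer pass (kcm, vcm) explicitly: the remaining key tag is resolved from the implicit scope.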

core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala

Lines changed: 2 additions & 2 deletions

@@ -24,8 +24,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.function.{Function => JFunction}
 import org.apache.spark.storage.StorageLevel
 
-class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends
-  JavaRDDLike[T, JavaRDD[T]] {
+class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
+  extends JavaRDDLike[T, JavaRDD[T]] {
 
   override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
 
core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 18 additions & 19 deletions

@@ -17,22 +17,23 @@
 
 package org.apache.spark.api.java
 
-import java.util.{List => JList, Comparator}
+import java.util.{Comparator, List => JList}
+
 import scala.Tuple2
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
 
 import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
-import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.api.java.JavaPairRDD._
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
-import org.apache.spark.partial.{PartialResult, BoundedDouble}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
-
 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def wrapRDD(rdd: RDD[T]): This
 
@@ -88,8 +89,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * Return a new RDD by applying a function to all elements of this RDD.
    */
   def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
-    def cm = implicitly[ClassTag[Tuple2[_, _]]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
-    new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
+    val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
+    new JavaPairRDD(rdd.map(f)(ctag))(f.keyType(), f.valueType())
   }
 
   /**
@@ -119,8 +120,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
     import scala.collection.JavaConverters._
     def fn = (x: T) => f.apply(x).asScala
-    def cm = implicitly[ClassTag[Tuple2[_, _]]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
-    JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+    val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
+    JavaPairRDD.fromRDD(rdd.flatMap(fn)(ctag))(f.keyType(), f.valueType())
   }
 
   /**
@@ -202,21 +203,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * mapping to that key.
    */
   def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
-    implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
-    implicit val vcm: ClassTag[JList[T]] =
-      implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
-    JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
+    implicit val ctagK: ClassTag[K] = fakeClassTag
+    implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
+    JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))
   }
 
   /**
    * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
    * mapping to that key.
    */
   def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
-    implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
-    implicit val vcm: ClassTag[JList[T]] =
-      implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
-    JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
+    implicit val ctagK: ClassTag[K] = fakeClassTag
+    implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
+    JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))
   }
 
   /**
@@ -407,7 +406,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * Creates tuples of the elements in this RDD by applying `f`.
    */
   def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
-    implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+    implicit val ctag: ClassTag[K] = fakeClassTag
     JavaPairRDD.fromRDD(rdd.keyBy(f))
   }
 
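A side note on the map and flatMap hunks above: the old code summoned a ClassTag[Tuple2[_, _]] and cast it, while the new code asks for implicitly[ClassTag[Tuple2[K2, V2]]] directly, even though K2 and V2 are abstract. That works because ClassTag materialization records only the erased top-level class, so no tags for K2 or V2 are needed. A minimal sketch demonstrating the point (names here are illustrative, not from the diff):

import scala.reflect.ClassTag

object TupleTagSketch {
  // K2 and V2 carry no ClassTags of their own, yet this compiles:
  // the materialized tag is simply ClassTag(classOf[Tuple2[_, _]]).
  def tupleTag[K2, V2]: ClassTag[Tuple2[K2, V2]] = implicitly[ClassTag[Tuple2[K2, V2]]]

  def main(args: Array[String]): Unit = {
    println(tupleTag[String, Int])  // scala.Tuple2
    // ClassTag equality compares only the runtime class, so all tuple tags agree:
    println(tupleTag[String, Int] == tupleTag[Double, Double])  // true
  }
}

So the cast in the old code bought nothing; the direct implicit lookup yields the same Tuple2-classed tag with less ceremony.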
