 
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -151,18 +155,37 @@ class Dataset[T] private[sql](
   def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
 
   /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
    * @since 1.6.0
    */
   def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+   * @since 1.6.0
+   */
+  def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
+    filter(t => func.call(t).booleanValue())
+
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
   def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    map(t => func.call(t))(encoder)
+
+  /**
+   * (Scala-specific)
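
Usage note (illustrative, not part of the patch): the Scala overloads pick up an implicit Encoder, while the Java overloads take one explicitly. A minimal sketch, assuming a SQLContext with `sqlContext.implicits._` in scope (the `JFunction` alias is the one this patch imports):

  val ds = Seq(1, 2, 3).toDS()          // Dataset[Int]
  val odds = ds.filter(_ % 2 == 1)      // Scala overload, Encoder[Int] is implicit
  val labels = odds.map(n => s"n=$n")   // Scala overload, Encoder[String] is implicit

  // The Java overload can also be exercised from Scala with an explicit function object:
  val evens = ds.filter(new JFunction[Int, java.lang.Boolean] {
    override def call(v: Int): java.lang.Boolean = v % 2 == 0
  })
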
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
    * @since 1.6.0
    */
@@ -177,37 +200,93 @@ class Dataset[T] private[sql](
         logicalPlan))
   }
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+   * @since 1.6.0
+   */
+  def mapPartitions[U](
+      f: FlatMapFunction[java.util.Iterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+    mapPartitions(func)(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
   def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
     mapPartitions(_.flatMap(func))
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
+  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (T) => Iterable[U] = x => f.call(x).asScala
+    flatMap(func)(encoder)
+  }
+
   /* ************** *
    *  Side effects  *
    * ************** */
 
   /**
+   * (Scala-specific)
    * Runs `func` on each element of this Dataset.
    * @since 1.6.0
    */
   def foreach(func: T => Unit): Unit = rdd.foreach(func)
 
   /**
+   * (Java-specific)
+   * Runs `func` on each element of this Dataset.
+   * @since 1.6.0
+   */
+  def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+
+  /**
+   * (Scala-specific)
    * Runs `func` on each partition of this Dataset.
    * @since 1.6.0
    */
   def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
 
+  /**
+   * (Java-specific)
+   * Runs `func` on each partition of this Dataset.
+   * @since 1.6.0
+   */
+  def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+    foreachPartition(it => func.call(it.asJava))
+
   /* ************* *
    *  Aggregation  *
    * ************* */
 
   /**
+   * (Scala-specific)
    * Reduces the elements of this Dataset using the specified binary function. The given function
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
   def reduce(func: (T, T) => T): T = rdd.reduce(func)
 
   /**
+   * (Java-specific)
+   * Reduces the elements of this Dataset using the specified binary function. The given function
+   * must be commutative and associative or the result may be non-deterministic.
+   * @since 1.6.0
+   */
+  def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
+
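
Usage note (illustrative, not part of the patch): `flatMap` followed by `reduce`, with the same implicits as above; the reducing function is commutative and associative, as the Scaladoc requires:

  val total = Seq(1, 2, 3).toDS()
    .flatMap(n => Seq(n, n * 10))   // elements 1, 10, 2, 20, 3, 30
    .reduce(_ + _)                  // 66, regardless of partitioning
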
+  /**
+   * (Scala-specific)
    * Aggregates the elements of each partition, and then the results for all the partitions, using a
    * given associative and commutative function and a neutral "zero value".
    *
@@ -221,6 +300,15 @@ class Dataset[T] private[sql](
   def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
 
   /**
+   * (Java-specific)
+   * Aggregates the elements of each partition, and then the results for all the partitions, using a
+   * given associative and commutative function and a neutral "zero value".
+   * @since 1.6.0
+   */
+  def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+
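
Usage note (illustrative, not part of the patch): the Java `fold` flattens the curried Scala form into a single parameter list:

  val ds = Seq(1, 2, 3).toDS()
  val sum = ds.fold(0)(_ + _)       // Scala overload: 6; the zero value must be an identity
  val sumJ = ds.fold(0, new JFunction2[Int, Int, Int] {
    override def call(a: Int, b: Int): Int = a + b
  })                                // Java overload, same result
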
+  /**
+   * (Scala-specific)
    * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
    * @since 1.6.0
    */
@@ -258,6 +346,14 @@ class Dataset[T] private[sql](
       keyAttributes)
   }
 
+  /**
+   * (Java-specific)
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * @since 1.6.0
+   */
+  def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupBy(f.call(_))(encoder)
+
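
Usage note (illustrative, not part of the patch): grouping by a derived key; the key Encoder is implicit in Scala and explicit in the Java overload:

  val grouped: GroupedDataset[Int, Int] = Seq(1, 2, 3, 4).toDS().groupBy(_ % 2)
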
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
@@ -267,8 +363,7 @@ class Dataset[T] private[sql](
    * {{{
    *   df.select($"colA", $"colB" + 1)
    * }}}
-   * @group dfops
-   * @since 1.3.0
+   * @since 1.6.0
    */
   // Copied from Dataframe to make sure we don't have invalid overloads.
   @scala.annotation.varargs
@@ -279,7 +374,7 @@ class Dataset[T] private[sql](
    *
    * {{{
    *   val ds = Seq(1, 2, 3).toDS()
-   *   val newDS = ds.select(e[Int]("value + 1"))
+   *   val newDS = ds.select(expr("value + 1").as[Int])
    * }}}
    * @since 1.6.0
    */
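
Usage note (illustrative, not part of the patch): `expr` comes from `org.apache.spark.sql.functions`, and `.as[Int]` turns the untyped Column into a TypedColumn so the result stays a typed Dataset:

  import org.apache.spark.sql.functions.expr

  val ds = Seq(1, 2, 3).toDS()
  val incremented: Dataset[Int] = ds.select(expr("value + 1").as[Int])
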
@@ -405,6 +500,8 @@ class Dataset[T] private[sql](
    * This type of join can be useful both for preserving type-safety with the original object
    * types as well as working with relational data where either side of the join has column
    * names in common.
+   *
+   * @since 1.6.0
    */
   def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
     val left = this.logicalPlan
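
Usage note (illustrative, not part of the patch): a sketch of `joinWith`, aliasing each side so the condition can name its `value` column; it assumes `Dataset.as(alias)` and the `$` interpolator from the SQLContext implicits:

  val ds1 = Seq(1, 2, 3).toDS().as("a")
  val ds2 = Seq(2, 3, 4).toDS().as("b")
  // Each result row keeps both typed sides: Dataset[(Int, Int)]
  val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
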
@@ -438,12 +535,31 @@ class Dataset[T] private[sql](
    *  Gather to Driver Actions  *
    * ************************** */
 
-  /** Returns the first element in this [[Dataset]]. */
+  /**
+   * Returns the first element in this [[Dataset]].
+   * @since 1.6.0
+   */
   def first(): T = rdd.first()
 
-  /** Collects the elements to an Array. */
+  /**
+   * Collects the elements to an Array.
+   * @since 1.6.0
+   */
   def collect(): Array[T] = rdd.collect()
 
+  /**
+   * (Java-specific)
+   * Collects the elements to a Java list.
+   *
+   * Due to an incompatibility between Scala and Java, the return type of [[collect()]] on the
+   * Java side is `java.lang.Object`, which is hard to use. Java users can call this method
+   * instead to keep the generic result type.
+   *
+   * @since 1.6.0
+   */
+  def collectAsList(): java.util.List[T] =
+    rdd.collect().toSeq.asJava
+
   /** Returns the first `num` elements of this [[Dataset]] as an Array. */
   def take(num: Int): Array[T] = rdd.take(num)
 
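Usage note (illustrative, not part of the patch): the two collection flavors side by side; `collectAsList` exists because `collect()` erases to `java.lang.Object` when called from Java:

  val ds = Seq(1, 2, 3).toDS()
  val asArray: Array[Int] = ds.collect()               // Scala-friendly
  val asList: java.util.List[Int] = ds.collectAsList() // Java-friendly, keeps the element type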