
Commit 02748c9

cloud-fan authored and rxin committed

[SPARK-11269][SQL] Java API support & test cases for Dataset

This simply brings #9358 up-to-date.

Author: Wenchen Fan <[email protected]>
Author: Reynold Xin <[email protected]>

Closes #9528 from rxin/dataset-java.

(cherry picked from commit 7e9a9e6)
Signed-off-by: Reynold Xin <[email protected]>

1 parent b58f1ce commit 02748c9

8 files changed: +644 -12 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala

Lines changed: 120 additions & 3 deletions

```diff
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-
-
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
   /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
   def clsTag: ClassTag[T]
 }
+
+object Encoder {
+  import scala.reflect.runtime.universe._
+
+  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+  }
+
+  def tuple[T1, T2, T3](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+  }
+
+  def tuple[T1, T2, T3, T4](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+  }
+
+  def tuple[T1, T2, T3, T4, T5](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4],
+      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+  }
+
+  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    assert(encoders.length > 1)
+    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+    })
+
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }.zipWithIndex.map { case (expr, index) =>
+      expr.transformUp {
+        case BoundReference(0, t: ObjectType, _) =>
+          Invoke(
+            BoundReference(0, ObjectType(cls), true),
+            s"_${index + 1}",
+            t)
+      }
+    }
+
+    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+      if (enc.flat) {
+        enc.constructExpression.transform {
+          case b: BoundReference => b.copy(ordinal = index)
+        }
+      } else {
+        enc.constructExpression.transformUp {
+          case BoundReference(ordinal, dt, _) =>
+            GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+        }
+      }
+    }
+
+    val constructExpression =
+      NewInstance(cls, constructExpressions, false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+
+
+  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
+
+  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
+    import scala.reflect.api
+
+    // val mirror = runtimeMirror(c.getClassLoader)
+    val mirror = rootMirror
+    val sym = mirror.staticClass(c.getName)
+    val tpe = sym.selfType
+    TypeTag(mirror, new api.TypeCreator {
+      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
+        if (m eq mirror) tpe.asInstanceOf[U # Type]
+        else throw new IllegalArgumentException(
+          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
+    })
+  }
+
+  def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+    implicit val typeTag1 = getTypeTag(c1)
+    implicit val typeTag2 = getTypeTag(c2)
+    ExpressionEncoder[(T1, T2)]()
+  }
+}
```
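
The `object Encoder` above is the Java-facing entry point: the `INT`/`STRING`/... factory methods give flat single-column encoders, and `tuple` composes encoders into one whose schema has top-level columns `_1`, `_2`, ... A minimal sketch of how this might look from Java (the `Encoder$.MODULE$` handle is the standard way Java reaches a Scala `object`; the class name `EncoderSketch` and the printed-output comment are illustrative, not part of this commit):

```java
import org.apache.spark.sql.catalyst.encoders.Encoder;
import org.apache.spark.sql.catalyst.encoders.Encoder$;
import scala.Tuple2;

public class EncoderSketch {
  public static void main(String[] args) {
    // Reach the Scala companion object from Java via its singleton instance.
    Encoder$ e = Encoder$.MODULE$;

    // A flat encoder maps one value to one top-level column.
    Encoder<Integer> intEnc = e.INT();

    // A tuple encoder flattens each component into a top-level column (_1, _2).
    Encoder<Tuple2<Integer, String>> pairEnc = e.tuple(e.INT(), e.STRING());

    // Expected to print a two-field struct: _1 of int type, _2 of string type.
    System.out.println(pairEnc.schema());
  }
}
```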

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 21 additions & 0 deletions

```diff
@@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
     s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
   }
 }
+
+case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
+  extends UnaryExpression {
+
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val row = child.gen(ctx)
+    s"""
+      ${row.code}
+      final boolean ${ev.isNull} = ${row.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
+      }
+    """
+  }
+}
```
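
`GetInternalRowField` is codegen-only: `eval` throws, and all real work happens in the Java source that `genCode` emits. As a rough illustration (not actual compiler output; the codegen context allocates its own fresh variable names), the template above would expand along these lines for an `IntegerType` field at ordinal 0:

```java
// ${row.code}: statements that compute the child row into `row` / `rowIsNull`.
final boolean isNull = rowIsNull;   // the field is null whenever the row itself is
int value = -1;                     // ctx.defaultValue(IntegerType)
if (!isNull) {
  value = row.getInt(0);            // ctx.getValue(row, IntegerType, "0")
}
```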

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 121 additions & 5 deletions

```diff
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -151,18 +155,37 @@ class Dataset[T] private[sql](
   def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
 
   /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
    * @since 1.6.0
    */
   def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+   * @since 1.6.0
+   */
+  def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
+    filter(t => func.call(t).booleanValue())
+
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
   def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    map(t => func.call(t))(encoder)
+
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
```
```diff
@@ -177,37 +200,93 @@ class Dataset[T] private[sql](
         logicalPlan))
   }
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def mapPartitions[U](
+      f: FlatMapFunction[java.util.Iterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+    mapPartitions(func)(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
   def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
     mapPartitions(_.flatMap(func))
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
+  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (T) => Iterable[U] = x => f.call(x).asScala
+    flatMap(func)(encoder)
+  }
+
   /* ************** *
    *  Side effects  *
    * ************** */
 
   /**
+   * (Scala-specific)
    * Runs `func` on each element of this Dataset.
    * @since 1.6.0
    */
   def foreach(func: T => Unit): Unit = rdd.foreach(func)
 
   /**
+   * (Java-specific)
+   * Runs `func` on each element of this Dataset.
+   * @since 1.6.0
+   */
+  def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+
+  /**
+   * (Scala-specific)
    * Runs `func` on each partition of this Dataset.
    * @since 1.6.0
    */
   def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
 
+  /**
+   * (Java-specific)
+   * Runs `func` on each partition of this Dataset.
+   * @since 1.6.0
+   */
+  def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+    foreachPartition(it => func.call(it.asJava))
+
   /* ************* *
    *  Aggregation  *
    * ************* */
 
   /**
+   * (Scala-specific)
    * Reduces the elements of this Dataset using the specified binary function. The given function
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
   def reduce(func: (T, T) => T): T = rdd.reduce(func)
 
   /**
+   * (Java-specific)
+   * Reduces the elements of this Dataset using the specified binary function. The given function
+   * must be commutative and associative or the result may be non-deterministic.
+   * @since 1.6.0
+   */
+  def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
+
+  /**
+   * (Scala-specific)
    * Aggregates the elements of each partition, and then the results for all the partitions, using a
    * given associative and commutative function and a neutral "zero value".
    *
```
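
Continuing the same sketch: the Java `flatMap` takes the 1.x-style `FlatMapFunction`, whose `call` returns an `Iterable` (hence the `.iterator().asScala` bridge in `mapPartitions` above), and `reduce` takes a `Function2` (same assumed `ds` and `lengths` datasets):

```java
import java.util.Arrays;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;

// One input element expands to zero or more output elements.
Dataset<String> words = ds.flatMap(new FlatMapFunction<String, String>() {
  @Override
  public Iterable<String> call(String line) throws Exception {
    return Arrays.asList(line.split(" "));
  }
}, Encoder$.MODULE$.STRING());

// Summing lengths is commutative and associative, as reduce requires.
Integer total = lengths.reduce(new Function2<Integer, Integer, Integer>() {
  @Override
  public Integer call(Integer a, Integer b) throws Exception {
    return a + b;
  }
});
```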
```diff
@@ -221,6 +300,15 @@ class Dataset[T] private[sql](
   def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
 
   /**
+   * (Java-specific)
+   * Aggregates the elements of each partition, and then the results for all the partitions, using a
+   * given associative and commutative function and a neutral "zero value".
+   * @since 1.6.0
+   */
+  def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+
+  /**
+   * (Scala-specific)
    * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
    * @since 1.6.0
    */
@@ -258,6 +346,14 @@ class Dataset[T] private[sql](
       keyAttributes)
   }
 
+  /**
+   * (Java-specific)
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * @since 1.6.0
+   */
+  def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupBy(f.call(_))(encoder)
+
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
```
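
The Java `fold` and `groupBy` variants follow the same two conventions: closures arrive as Spark's Java function interfaces, and wherever a new element or key type appears, its encoder is passed explicitly (still the same hypothetical `ds` and `lengths`):

```java
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.GroupedDataset;

// fold = reduce with a neutral "zero value".
Integer sum = lengths.fold(0, new Function2<Integer, Integer, Integer>() {
  @Override
  public Integer call(Integer a, Integer b) throws Exception {
    return a + b;
  }
});

// groupBy computes a key per element; the key's encoder is explicit.
GroupedDataset<Integer, String> byLength = ds.groupBy(new Function<String, Integer>() {
  @Override
  public Integer call(String v) throws Exception {
    return v.length();
  }
}, Encoder$.MODULE$.INT());
```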
```diff
@@ -267,8 +363,7 @@ class Dataset[T] private[sql](
    * {{{
    *   df.select($"colA", $"colB" + 1)
    * }}}
-   * @group dfops
-   * @since 1.3.0
+   * @since 1.6.0
    */
   // Copied from Dataframe to make sure we don't have invalid overloads.
   @scala.annotation.varargs
@@ -279,7 +374,7 @@ class Dataset[T] private[sql](
    *
    * {{{
    *   val ds = Seq(1, 2, 3).toDS()
-   *   val newDS = ds.select(e[Int]("value + 1"))
+   *   val newDS = ds.select(expr("value + 1").as[Int])
    * }}}
    * @since 1.6.0
    */
@@ -405,6 +500,8 @@ class Dataset[T] private[sql](
    * This type of join can be useful both for preserving type-safety with the original object
    * types as well as working with relational data where either side of the join has column
    * names in common.
+   *
+   * @since 1.6.0
    */
   def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
     val left = this.logicalPlan
@@ -438,12 +535,31 @@ class Dataset[T] private[sql](
    *  Gather to Driver Actions  *
    * ************************** */
 
-  /** Returns the first element in this [[Dataset]]. */
+  /**
+   * Returns the first element in this [[Dataset]].
+   * @since 1.6.0
+   */
   def first(): T = rdd.first()
 
-  /** Collects the elements to an Array. */
+  /**
+   * Collects the elements to an Array.
+   * @since 1.6.0
+   */
   def collect(): Array[T] = rdd.collect()
 
+  /**
+   * (Java-specific)
+   * Collects the elements to a Java list.
+   *
+   * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
+   * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
+   * instead and keep the generic type for result.
+   *
+   * @since 1.6.0
+   */
+  def collectAsList(): java.util.List[T] =
+    rdd.collect().toSeq.asJava
+
   /** Returns the first `num` elements of this [[Dataset]] as an Array. */
   def take(num: Int): Array[T] = rdd.take(num)
 
```
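
The motivation for `collectAsList` is erasure at the Scala/Java boundary: `collect()` is typed through a Scala `ClassTag`, so, as the doc comment above notes, Java callers get a plain `Object` back. The list variant keeps the element type (sketch, same assumed `ds`):

```java
import java.util.List;

// The generic element type survives for Java callers.
List<String> rows = ds.collectAsList();
for (String row : rows) {
  System.out.println(row);
}
```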