From ae356d103cd8264b449ad7b682c2af21fed6e2a4 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 1 Feb 2016 17:42:34 -0800 Subject: [PATCH 1/2] [SPARK-13094][SQL] Add encoders for seq/array of primitives --- .../org/apache/spark/sql/SQLImplicits.scala | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ab414799f1a42..34a7aca5c156b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -39,6 +39,8 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + // Primitives + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() @@ -56,13 +58,65 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() - /** @since 1.6.0 */ + /** @since 1.6.0 */ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + // Seqs + + /** @since 1.6.1 */ + implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + + // Arrays + + /** @since 1.6.1 */ + implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + /** * Creates a [[Dataset]] from an RDD. * @since 1.6.0 From 0b5f03c41ef7374c0a70771acb198ad61799fdec Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 1 Feb 2016 21:30:08 -0800 Subject: [PATCH 2/2] product and tests --- .../org/apache/spark/sql/SQLImplicits.scala | 7 ++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 22 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 8 ++++++- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 34a7aca5c156b..16c4095db722a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -91,6 +91,9 @@ abstract class SQLImplicits { /** @since 1.6.1 */ implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + /** @since 1.6.1 */ + implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ @@ -117,6 +120,10 @@ abstract class SQLImplicits { /** @since 1.6.1 */ implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = + ExpressionEncoder() + /** * Creates a [[Dataset]] from an RDD. * @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index f75d0961823c4..243d13b19d6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { agged, "1", "abc", "3", "xyz", "5", "hello") } + + test("Arrays and Lists") { + checkAnswer(Seq(Seq(1)).toDS(), Seq(1)) + checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkAnswer(Seq(Seq(true)).toDS(), Seq(true)) + checkAnswer(Seq(Seq("test")).toDS(), Seq("test")) + checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkAnswer(Seq(Array(1)).toDS(), Array(1)) + checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkAnswer(Seq(Array(true)).toDS(), Array(true)) + checkAnswer(Seq(Array("test")).toDS(), Array("test")) + checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ce12f788b786c..dbf5e6404881e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - if (decoded != expectedAnswer.toSet) { + // Handle the case where the return type is an array + val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) + def normalEquality = decoded == expectedAnswer.toSet + def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet + def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) + + if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted