@@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection {
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
* `Data` with int and long, because we lost the information that `b` should be a string.
*
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
* don't need to cast struct type because there must be `UnresolvedExtractValue` or
* `GetStructField` wrapping it, thus we only need to handle leaf type.
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
* only need to do this for leaf nodes.
*/
def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
// it's not trivial to support by-name resolution for StructType inside MapType.
case _ => UpCast(expr, expected, walkedTypePath)
}

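As a rough, user-level illustration of the scenario described in the comment above (hypothetical names; assumes a SparkSession `spark` with `spark.implicits._` in scope, e.g. in spark-shell): fields are matched by name, and the `UpCast` added on each leaf column records the type the encoder expects.

```scala
// Hypothetical example: columns arrive reordered and with a narrower type.
case class Data(a: Int, b: Long)

val df = spark.range(3).selectExpr("cast(id as int) as b", "cast(id as int) as a")

// Resolution is by name, and the UpCast on the leaf column `b` remembers that
// the encoder expects a bigint, so the int column is safely widened.
val ds = df.as[Data]
ds.collect()  // Array(Data(0,0), Data(1,1), Data(2,2))
```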
@@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection {

case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(_, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath

// TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
case t if t <:< definitions.FloatTpe => Some("toFloatArray")
case t if t <:< definitions.ShortTpe => Some("toShortArray")
case t if t <:< definitions.ByteTpe => Some("toByteArray")
case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
case _ => None
val mapFunction: Expression => Expression = p => {
val converter = deserializerFor(elementType, Some(p), newTypePath)
if (elementNullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
}
}

primitiveMethod.map { method =>
Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
p => deserializerFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
arrayClassFor(elementType))
val arrayData = UnresolvedMapObjects(mapFunction, getPath)
val arrayCls = arrayClassFor(elementType)

if (elementNullable) {
Invoke(arrayData, "array", arrayCls)
} else {
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => "toIntArray"
case t if t <:< definitions.LongTpe => "toLongArray"
case t if t <:< definitions.DoubleTpe => "toDoubleArray"
case t if t <:< definitions.FloatTpe => "toFloatArray"
case t if t <:< definitions.ShortTpe => "toShortArray"
case t if t <:< definitions.ByteTpe => "toByteArray"
case t if t <:< definitions.BooleanTpe => "toBooleanArray"
case other => throw new IllegalStateException("expect primitive array element type " +
"but got " + other)
}
Invoke(arrayData, primitiveMethod, arrayCls)
}

case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
val Schema(_, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath

val mapFunction: Expression => Expression = p => {
val converter = deserializerFor(elementType, Some(p), newTypePath)
if (nullable) {
if (elementNullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
@@ -311,7 +319,7 @@ object ScalaReflection extends ScalaReflection {
case NoSymbol => classOf[Seq[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
MapObjects(mapFunction, getPath, dataType, Some(cls))
UnresolvedMapObjects(mapFunction, getPath, Some(cls))

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
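A minimal end-to-end sketch of what the new array handling enables (hypothetical case classes; assumes spark-shell with `spark.implicits._` and `org.apache.spark.sql.functions._` available): the struct inside the array lists its fields in a different order than the target class, and `UnresolvedMapObjects` defers per-element resolution until the array's element type is known.

```scala
import org.apache.spark.sql.functions._

case class Inner(a: Int, b: String)
case class Outer(items: Seq[Inner])

// The element struct's fields come as (b, a); they are matched to Inner by name.
val df = spark.range(2).select(
  array(struct(lit("x").as("b"), $"id".cast("int").as("a"))).as("items"))

df.as[Outer].collect()  // roughly: Array(Outer(List(Inner(0,x))), Outer(List(Inner(1,x))))
```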
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
@@ -2226,8 +2226,21 @@ class Analyzer(
validateTopLevelTupleFields(deserializer, inputs)
val resolved = resolveExpression(
deserializer, LocalRelation(inputs), throws = true)
validateNestedTupleFields(resolved)
resolved
val result = resolved transformDown {
case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
inputData.dataType match {
case ArrayType(et, _) =>
val expr = MapObjects(func, inputData, et, cls) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
expr
case other =>
throw new AnalysisException("need an array field but got " + other.simpleString)
}
}
validateNestedTupleFields(result)
result
}
}

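A hypothetical failure case for the new error path (assumes spark-shell with `spark.implicits._` in scope): when the resolved input for an `UnresolvedMapObjects` is not an array, analysis now fails fast with the message exercised below in EncoderResolutionSuite.

```scala
case class Wrapper(arr: Seq[Int])

// `arr` is an int column, but the encoder expects an array, so deserializer
// resolution throws an AnalysisException: "need an array field but got int".
val bad = spark.range(1).selectExpr("cast(id as int) as arr")
bad.as[Wrapper]
```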
@@ -68,7 +68,7 @@ object ExtractValue {
case StructType(_) =>
s"Field name should be String Literal, but it's $extraction"
case other =>
s"Can't extract value from $child"
s"Can't extract value from $child: need struct type but got ${other.simpleString}"
}
throw new AnalysisException(errorMsg)
}
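One hypothetical way to hit the improved message (assumes spark-shell with `spark.implicits._` in scope): extracting a field from a column that is neither a struct, an array, nor a map.

```scala
val df = Seq(1).toDF("a")

// Previously the error was roughly: "Can't extract value from a#0"
// Now the offending data type is included:
// "Can't extract value from a#0: need struct type but got int"
df.select($"a".getField("b"))
```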
@@ -448,6 +448,17 @@ object MapObjects {
}
}

case class UnresolvedMapObjects(
function: Expression => Expression,
child: Expression,
customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
override lazy val resolved = false

override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
throw new UnsupportedOperationException("not resolved")
}
}

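A small sketch of this placeholder node's contract (illustrative only; `Literal(1)` is just a convenient resolved child expression): it reports itself as unresolved so the analyzer must rewrite it into a concrete `MapObjects`, and it can only report a `dataType` when a custom collection class was supplied.

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.UnresolvedMapObjects

val placeholder = UnresolvedMapObjects((e: Expression) => e, Literal(1))
placeholder.resolved   // false: the analyzer must replace it with MapObjects
// placeholder.dataType would throw UnsupportedOperationException("not resolved"),
// since no customCollectionCls was given to build an ObjectType from.
```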
/**
* Applies the given expression to every element of a collection of items, returning the result
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
@@ -581,17 +592,24 @@ case class MapObjects private(
// collection
val collObjectName = s"${cls.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"

(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);""",
(
s"""
[PR author comment on this hunk: below are style-only changes]
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);
""",
genValue => s"$builderValue.$$plus$$eq($genValue);",
s"(${cls.getName}) $builderValue.result();")
s"(${cls.getName}) $builderValue.result();"
)
case None =>
// array
(s"""$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;""",
(
s"""
$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;
""",
genValue => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
s"new ${classOf[GenericArrayData].getName}($convertedArray);"
)
}

val code = s"""
@@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int)

case class ComplexClass(a: Long, b: StringLongClass)

case class ArrayClass(arr: Seq[StringIntClass])

case class NestedArrayClass(nestedArr: Array[ArrayClass])

class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")

@@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest {
encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}

test("real type doesn't match encoder schema but they are compatible: array") {
[Reviewer comment (Contributor): Is there a reason we don't have any encodeDecodeTest?]
[Author reply (@cloud-fan, Mar 25, 2017): encodeDecodeTest is a round trip test so that the type and schema match exactly. I added an end-to-end test in DatasetSuite.]
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
}

test("real type doesn't match encoder schema but they are compatible: nested array") {
val encoder = ExpressionEncoder[NestedArrayClass]
val et = new StructType().add("arr", ArrayType(
new StructType().add("a", "int").add("b", "int").add("c", "int")))
val attrs = Seq('nestedArr.array(et))
val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
}

test("the real type is not compatible with encoder schema: non-array field") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.int)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"need an array field but got int")
}

test("the real type is not compatible with encoder schema: array element type") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("c", "int")))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"No such struct field a in c")
}

test("the real type is not compatible with encoder schema: nested array element type") {
val encoder = ExpressionEncoder[NestedArrayClass]

withClue("inner element is not array") {
val attrs = Seq('nestedArr.array(new StructType().add("arr", "int")))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"need an array field but got int")
}

withClue("nested array element type is not compatible") {
val attrs = Seq('nestedArr.array(new StructType()
.add("arr", ArrayType(new StructType().add("c", "int")))))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"No such struct field a in c")
}
}

test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
extends Aggregator[T, (Boolean, T), T] {

private val encoder = implicitly[Encoder[T]]
@transient private val encoder = implicitly[Encoder[T]]

override def zero: (Boolean, T) = (false, null.asInstanceOf[T])

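On the `@transient` change: as far as I can tell, the captured `ExpressionEncoder` is only needed on the driver (for `bufferEncoder`/`outputEncoder`), while the aggregator instance itself is shipped to executors for `reduce`; marking the field `@transient` keeps a possibly non-serializable encoder (whose deserializer may now contain `UnresolvedMapObjects` function closures) out of the serialized instance. A rough sketch of the code path that exercises it (assumes spark-shell with `spark.implicits._` in scope):

```scala
// groupByKey(...).reduceGroups(...) builds a ReduceAggregator internally;
// reduce() on executors never touches the encoder, so it can be @transient.
val ds = Seq(1, 2, 3, 4).toDS()
ds.groupByKey(_ % 2).reduceGroups(_ + _).collect()
```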
@@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}

test("as seq of case class - reorder fields by name") {
val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
val ds = df.as[Seq[ClassData]]
assert(ds.collect() === Array(
Seq(ClassData("a", 0)),
Seq(ClassData("a", 1)),
Seq(ClassData("a", 2))))
}

test("map") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(