
Commit e9dbd7b

add type cast if the real type is different but compatible with encoder schema
1 parent: ee21407

4 files changed: +180 -8 lines changed
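In short: when a Dataset's underlying row schema has column types that differ from, but are safely convertible to, the encoder's schema, ExpressionEncoder.resolve now inserts the necessary casts instead of producing a deserializer that reads the wrong types. A REPL-style sketch of the scenario this enables, reusing the StringLongClass shape from the new test suite below; the in-scope sqlContext and the exact collected values are assumptions modeled on the new DatasetSuite test, not something stated directly by this commit:

    // sketch only: assumes a Spark 1.6-era shell with an in-scope `sqlContext`
    import sqlContext.implicits._

    case class StringLongClass(a: String, b: Long)

    // column types are (Int, Byte); the encoder schema is (String, Long)
    val df = Seq((1, 1.toByte), (2, 2.toByte)).toDF("a", "b")

    // with this change, resolve() casts Int -> String and Byte -> Long
    val ds = df.as[StringLongClass]
    ds.collect()   // expected under these assumptions: StringLongClass("1", 1), StringLongClass("2", 2)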

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 52 additions & 6 deletions
@@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, ObjectType, StructType}

 /**
  * A factory for constructing encoders that convert objects and primitives to and from the
@@ -210,26 +211,71 @@ case class ExpressionEncoder[T](
     })
   }

+  private def handleStruct(input: Expression, s: StructType): Expression = {
+    assert(input.isInstanceOf[NewInstance] || input.isInstanceOf[CreateExternalRow])
+    val children = input.children
+    assert(children.length == s.length)
+
+    val newChildren = children.zip(s.map(_.dataType)).map {
+      case (child, dt) => typeCast(child, dt)
+    }
+
+    input.withNewChildren(newChildren)
+  }
+
+  private def typeCast(input: Expression, expectedType: DataType): Expression = expectedType match {
+    case s: StructType =>
+      var continue = true
+      input transformDown {
+        case c: CreateExternalRow if continue =>
+          continue = false
+          handleStruct(c, s)
+        case n: NewInstance if continue =>
+          continue = false
+          handleStruct(n, s)
+      }
+
+    case _ =>
+      var continue = true
+      input transformDown {
+        case u: UnresolvedExtractValue if continue =>
+          continue = false
+          Cast(u, expectedType)
+        case g: GetInternalRowField if continue =>
+          continue = false
+          Cast(g, expectedType)
+        case u: UnresolvedAttribute if continue =>
+          continue = false
+          Cast(u, expectedType)
+        case a: AttributeReference if continue =>
+          continue = false
+          Cast(a, expectedType)
+      }
+  }
+
   /**
    * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
    * given schema.
    */
   def resolve(
-      schema: Seq[Attribute],
+      attrs: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
-    val unbound = fromRowExpression transform {
+    val positionToAttribute = AttributeMap.toIndex(attrs)
+    val unbound = fromRowExpression transformUp {
       case b: BoundReference => positionToAttribute(b.ordinal)
     }

-    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
+    val withTypeCast = typeCast(unbound, if (flat) schema.head.dataType else schema)
+
+    val plan = Project(Alias(withTypeCast, "")() :: Nil, LocalRelation(attrs))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val optimizedPlan = SimplifyCasts(analyzedPlan)

     // In order to construct instances of inner classes (for example those declared in a REPL cell),
     // we need an instance of the outer scope. This rule substitues those outer objects into
     // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
     // registry.
-    copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
+    copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
       case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
         val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
         if (outer == null) {
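Note how each branch of typeCast guards its transformDown with a mutable continue flag: Catalyst's transformDown applies the rule to a node and then recurses into the result's children, so without the flag the rule could fire again on the attribute it has just wrapped in a Cast. A minimal, Spark-free sketch of that "rewrite only the first match" pattern, using made-up Node/Leaf/Wrap/Branch types rather than Catalyst's TreeNode:

    object FirstMatchOnlyDemo extends App {
      sealed trait Node
      case class Leaf(name: String) extends Node
      case class Wrap(child: Node) extends Node
      case class Branch(children: Seq[Node]) extends Node

      // simplified top-down transform: rewrite the current node first, then
      // recurse into the (possibly rewritten) node's children
      def transformDown(n: Node)(rule: PartialFunction[Node, Node]): Node =
        rule.applyOrElse(n, identity[Node]) match {
          case l: Leaf    => l
          case Wrap(c)    => Wrap(transformDown(c)(rule))
          case Branch(cs) => Branch(cs.map(transformDown(_)(rule)))
        }

      var continue = true
      val rewritten = transformDown(Branch(Seq(Leaf("a"), Leaf("b")))) {
        case l: Leaf if continue =>
          continue = false
          Wrap(l) // only the first leaf is wrapped; the guard stops further rewrites
      }

      // prints Branch(List(Wrap(Leaf(a)), Leaf(b)))
      println(rewritten)
    }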

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

   /**
-   * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
+   * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
    * StructType.
    */
   def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types._

case class StringLongClass(a: String, b: Long)

case class ComplexClass(a: Long, b: StringLongClass)

class EncoderResolveSuite extends PlanTest {
  test("real type doesn't match encoder schema but they are compatible: product") {
    val encoder = ExpressionEncoder[StringLongClass]
    val cls = classOf[StringLongClass]

    var attrs = Seq('a.string, 'b.int)
    var fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
    var expected: Expression = NewInstance(
      cls,
      toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
      false,
      ObjectType(cls))
    compareExpressions(fromRowExpr, expected)

    attrs = Seq('a.int, 'b.long)
    fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
    expected = NewInstance(
      cls,
      toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
      false,
      ObjectType(cls))
    compareExpressions(fromRowExpr, expected)
  }

  test("real type doesn't match encoder schema but they are compatible: nested product") {
    val encoder = ExpressionEncoder[ComplexClass]
    val innerCls = classOf[StringLongClass]
    val cls = classOf[ComplexClass]

    val structType = new StructType().add("a", IntegerType).add("b", LongType)
    val attrs = Seq('a.int, 'b.struct(structType))
    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
    val expected: Expression = NewInstance(
      cls,
      Seq(
        'a.int.cast(LongType),
        If(
          'b.struct(structType).isNull,
          Literal.create(null, ObjectType(innerCls)),
          NewInstance(
            innerCls,
            Seq(
              toExternalString(GetStructField(
                'b.struct(structType),
                structType(0),
                0).cast(StringType)),
              GetStructField(
                'b.struct(structType),
                structType(1),
                1)),
            false,
            ObjectType(innerCls))
        )),
      false,
      ObjectType(cls))
    compareExpressions(fromRowExpr, expected)
  }

  test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
    val encoder = ExpressionEncoder.tuple(
      ExpressionEncoder[StringLongClass],
      ExpressionEncoder[Long])
    val cls = classOf[StringLongClass]

    val structType = new StructType().add("a", StringType).add("b", ByteType, false)
    val attrs = Seq('a.struct(structType), 'b.int)
    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
    val expected: Expression = NewInstance(
      classOf[Tuple2[_, _]],
      Seq(
        NewInstance(
          cls,
          Seq(
            toExternalString(GetStructField(
              'a.struct(structType),
              structType(0),
              0)),
            GetStructField(
              'a.struct(structType),
              structType(1),
              1).cast(LongType)),
          false,
          ObjectType(cls)),
        'b.int.cast(LongType)),
      false,
      ObjectType(classOf[Tuple2[_, _]]))
    compareExpressions(fromRowExpr, expected)
  }

  private def toExternalString(e: Expression): Expression = {
    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
  }
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 6 additions & 1 deletion
@@ -384,7 +384,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq((JavaData(1), 1L), (JavaData(2), 1L)))
   }

-  ignore("Java encoder self join") {
+  test("Java encoder self join") {
     implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
     val ds = Seq(JavaData(1), JavaData(2)).toDS()
     assert(ds.joinWith(ds, lit(true)).collect().toSet ==
@@ -394,6 +394,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (JavaData(2), JavaData(1)),
       (JavaData(2), JavaData(2))))
   }
+
+  test("change encoder with compatible schema") {
+    val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
+    assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
+  }
 }