Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ object ScalaReflection extends ScalaReflection {
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
Invoke(obj, "serialize", udt, inputObject :: Nil)

case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
Expand All @@ -577,7 +577,7 @@ object ScalaReflection extends ScalaReflection {
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
Invoke(obj, "serialize", udt, inputObject :: Nil)

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
Expand Down Expand Up @@ -1887,7 +1887,9 @@ class Analyzer(
val unbound = deserializer transform {
case b: BoundReference => inputs(b.ordinal)
}
resolveExpression(unbound, LocalRelation(inputs), throws = true)

val resolved = resolveExpression(unbound, LocalRelation(inputs), throws = true)
ExpressionEncoder.checkNullFlag(inputs, resolved)
}
}
}
Expand Down Expand Up @@ -2129,4 +2131,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

/**
Expand All @@ -52,7 +52,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)

val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
val serializer = ScalaReflection.serializerFor[T](inputObject)
val deserializer = ScalaReflection.deserializerFor[T]

Expand Down Expand Up @@ -107,7 +107,10 @@ object ExpressionEncoder {

val serializer = encoders.map {
case e if e.flat => e.serializer.head
case other => CreateStruct(other.serializer)
case other =>
val inputObject = other.serializer.head.find(_.isInstanceOf[BoundReference]).get
val struct = CreateStruct(other.serializer)
If(IsNull(inputObject), Literal.create(null, struct.dataType), struct)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t, _) =>
Expand All @@ -125,12 +128,13 @@ object ExpressionEncoder {
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
enc.deserializer.transformUp {
val deserializer = enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
}
If(IsNull(input), Literal.create(null, deserializer.dataType), deserializer)
}
}

Expand Down Expand Up @@ -170,6 +174,29 @@ object ExpressionEncoder {
e4: ExpressionEncoder[T4],
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]

// Name shared by the null-flag column's alias and by the metadata key that marks it.
private val nullFlagName = "is_null_obj"
// Metadata attached to the null-flag column: one boolean entry keyed by `nullFlagName`, set to true.
private val nullFlagMeta = new MetadataBuilder().putBoolean(nullFlagName, true).build()

/**
 * Builds the extra "null flag" column for a serialized object.
 *
 * The column is an [[Alias]] named `nullFlagName` over `IsNull(inputObject)`, and it carries
 * `nullFlagMeta` as explicit metadata so that `isNullFlagColumn` can recognize it later.
 *
 * @param inputObject expression whose nullness the flag column reports
 * @return the aliased null-check expression
 */
def nullFlagColumn(inputObject: Expression): NamedExpression = {
  val objectIsNull = IsNull(inputObject)
  val marker = Some(nullFlagMeta)
  Alias(objectIsNull, nullFlagName)(explicitMetadata = marker)
}

/**
 * Tells whether `a` is a null-flag column.
 *
 * An attribute qualifies when its data type is [[BooleanType]] and its metadata contains the
 * `nullFlagName` key. The metadata is only inspected for boolean attributes.
 *
 * @param a attribute to inspect
 * @return true if `a` has the null-flag shape described above
 */
def isNullFlagColumn(a: Attribute): Boolean = {
  a.dataType match {
    case BooleanType => a.metadata.contains(nullFlagName)
    case _ => false
  }
}

/**
 * Guards `output` with the null-flag column found among `input`, if there is one.
 *
 *  - No null-flag attribute in `input`: `output` is returned unchanged.
 *  - Exactly one: the result is an [[If]] that yields a null literal of `output.dataType`
 *    when the flag is itself null or evaluates to true, and `output` otherwise.
 *  - More than one: an [[IllegalStateException]] is thrown.
 *
 * @param input  attributes to search for a null-flag column (see `isNullFlagColumn`)
 * @param output expression to protect
 * @return `output`, possibly wrapped in the null guard
 * @throws IllegalStateException if `input` contains more than one null-flag column
 */
def checkNullFlag(input: Seq[Attribute], output: Expression): Expression = {
  input.filter(isNullFlagColumn) match {
    case Seq() =>
      output
    case Seq(flag) =>
      // A null flag value is treated the same as "object is null".
      val objIsNull = Or(IsNull(flag), flag)
      val nullResult = Literal.create(null, output.dataType)
      If(objIsNull, nullResult, output)
    case _ =>
      throw new IllegalStateException("more than one null flag columns are found.")
  }
}
}

/**
Expand All @@ -194,6 +221,9 @@ case class ExpressionEncoder[T](
@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)

// Lazily generated unsafe projection over `serializerWithNullFlag`; applied by `toRowWithNullFlag`.
@transient
private lazy val serProjWithNullFlag = GenerateUnsafeProjection.generate(serializerWithNullFlag)

@transient
private lazy val inputRow = new GenericMutableRow(1)

Expand All @@ -209,16 +239,24 @@ case class ExpressionEncoder[T](
resolve(attrs, OuterScopes.outerScopes).bind(attrs)
}


/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
def namedSerializer: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
case (_, ne: NamedExpression) => ne.newInstance()
case (name, e) => Alias(e, name)()
}

/**
 * Returns the named serializer expressions, followed by a null-flag column for non-flat encoders.
 *
 * Flat encoders return `namedSerializer` as is. Otherwise the first [[BoundReference]] found in
 * the first serializer expression is taken as the input object, and
 * `ExpressionEncoder.nullFlagColumn` for it is appended after the regular columns.
 */
def serializerWithNullFlag: Seq[NamedExpression] = {
  if (!flat) {
    // Locate the input object first, as the original does, before creating any named columns.
    val boundInput = serializer.head.find {
      case _: BoundReference => true
      case _ => false
    }.get
    val columns = namedSerializer
    columns :+ ExpressionEncoder.nullFlagColumn(boundInput)
  } else {
    namedSerializer
  }
}

/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
Expand All @@ -233,6 +271,11 @@ case class ExpressionEncoder[T](
s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e)
}

/**
 * Returns the row produced by applying `serProjWithNullFlag` to `t`.
 *
 * `t` is first written into slot 0 of the reusable `inputRow`, which is then passed to the
 * projection, so the two statements must stay in this order.
 * NOTE(review): unlike the error wrapping visible at the end of `toRow`, failures here are
 * not wrapped with the serializer tree -- confirm whether that is intentional.
 */
def toRowWithNullFlag(t: T): InternalRow = {
inputRow(0) = t
serProjWithNullFlag(inputRow)
}

/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
* you must `resolve` and `bind` an encoder to a specific schema before you can call this
Expand Down Expand Up @@ -326,7 +369,8 @@ case class ExpressionEncoder[T](
LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
val newDeserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head
copy(deserializer = ExpressionEncoder.checkNullFlag(schema, newDeserializer))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,8 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}

def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoder.deserializer)
DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
}

def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}

def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
SerializeFromObject(encoder.namedExpressions, child)
SerializeFromObject(encoderFor[T].serializerWithNullFlag, child)
}

def generateObjAttr[T : Encoder]: Attribute = {
Expand Down Expand Up @@ -128,7 +119,7 @@ object MapPartitionsInR {
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
val deserialized = CatalystSerde.deserialize(child, encoder)
val deserialized = CatalystSerde.deserialize(child)(encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
Expand All @@ -137,7 +128,7 @@ object MapPartitionsInR {
schema,
CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
deserialized)
CatalystSerde.serialize(mapped, RowEncoder(schema))
CatalystSerde.serialize(mapped)(RowEncoder(schema))
}
}

Expand Down Expand Up @@ -185,7 +176,7 @@ object AppendColumns {
new AppendColumns(
func.asInstanceOf[Any => Any],
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
encoderFor[U].serializerWithNullFlag,
child)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ object Metadata {
JString(x)
case x: Metadata =>
toJsonValue(x.map)
case null => JNull
case other =>
throw new RuntimeException(s"Do not support type ${other.getClass}.")
}
Expand All @@ -208,6 +209,7 @@ object Metadata {
x.##
case x: Metadata =>
hash(x.map)
case null => 0
case other =>
throw new RuntimeException(s"Do not support type ${other.getClass}.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
val inputPlan = LocalRelation(attr)
val plan =
Project(Alias(encoder.deserializer, "obj")() :: Nil,
Project(encoder.namedExpressions,
Project(encoder.serializerWithNullFlag,
inputPlan))
assertAnalysisSuccess(plan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class EliminateSerializationSuite extends PlanTest {

val expected = AppendColumnsWithObject(
func.asInstanceOf[Any => Any],
productEncoder[(Int, Int)].namedExpressions,
intEncoder.namedExpressions,
productEncoder[(Int, Int)].serializerWithNullFlag,
intEncoder.serializerWithNullFlag,
input).analyze

comparePlans(optimized, expected)
Expand Down
Loading