@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
2121
2222import org .apache .spark .sql .catalyst .expressions ._
2323import org .apache .spark .sql .catalyst .expressions .aggregate .NoOp
24- import org .apache .spark .sql .types .{ StringType , StructType , DataType }
24+ import org .apache .spark .sql .types ._
2525
2626
2727/**
@@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
3636 protected def bind (in : Seq [Expression ], inputSchema : Seq [Attribute ]): Seq [Expression ] =
3737 in.map(BindReferences .bindReference(_, inputSchema))
3838
39- private def genUpdater (
39+ private def createCodeForStruct (
4040 ctx : CodeGenContext ,
41- setter : String ,
42- dataType : DataType ,
43- ordinal : Int ,
44- value : String ): String = {
45- dataType match {
46- case struct : StructType =>
47- val rowTerm = ctx.freshName(" row" )
48- val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
49- val colTerm = ctx.freshName(" col" )
50- s """
51- if ( $value.isNullAt( $i)) {
52- $rowTerm.setNullAt( $i);
53- } else {
54- ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s " $i" )};
55- ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
56- }
57- """
58- }.mkString(" \n " )
59- s """
60- $genericMutableRowType $rowTerm = new $genericMutableRowType( ${struct.fields.length});
61- $updates
62- $setter.update( $ordinal, $rowTerm.copy());
63- """
64- case _ =>
65- ctx.setColumn(setter, dataType, ordinal, value)
66- }
41+ input : String ,
42+ schema : StructType ): GeneratedExpressionCode = {
43+ val tmp = ctx.freshName(" tmp" )
44+ val output = ctx.freshName(" safeRow" )
45+ val values = ctx.freshName(" values" )
46+ val rowClass = classOf [GenericInternalRow ].getName
47+
48+ val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
49+ val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
50+ s """
51+ if (! $tmp.isNullAt( $i)) {
52+ ${converter.code}
53+ $values[ $i] = ${converter.primitive};
54+ }
55+ """
56+ }.mkString(" \n " )
57+
58+ val code = s """
59+ final InternalRow $tmp = $input;
60+ final Object[] $values = new Object[ ${schema.length}];
61+ $fieldWriters
62+ final InternalRow $output = new $rowClass( $values);
63+ """
64+
65+ GeneratedExpressionCode (code, " false" , output)
66+ }
67+
68+ private def createCodeForArray (
69+ ctx : CodeGenContext ,
70+ input : String ,
71+ elementType : DataType ): GeneratedExpressionCode = {
72+ val tmp = ctx.freshName(" tmp" )
73+ val output = ctx.freshName(" safeArray" )
74+ val values = ctx.freshName(" values" )
75+ val numElements = ctx.freshName(" numElements" )
76+ val index = ctx.freshName(" index" )
77+ val arrayClass = classOf [GenericArrayData ].getName
78+
79+ val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
80+ val code = s """
81+ final ArrayData $tmp = $input;
82+ final int $numElements = $tmp.numElements();
83+ final Object[] $values = new Object[ $numElements];
84+ for (int $index = 0; $index < $numElements; $index++) {
85+ if (! $tmp.isNullAt( $index)) {
86+ ${elementConverter.code}
87+ $values[ $index] = ${elementConverter.primitive};
88+ }
89+ }
90+ final ArrayData $output = new $arrayClass( $values);
91+ """
92+
93+ GeneratedExpressionCode (code, " false" , output)
94+ }
95+
96+ private def createCodeForMap (
97+ ctx : CodeGenContext ,
98+ input : String ,
99+ keyType : DataType ,
100+ valueType : DataType ): GeneratedExpressionCode = {
101+ val tmp = ctx.freshName(" tmp" )
102+ val output = ctx.freshName(" safeMap" )
103+ val mapClass = classOf [ArrayBasedMapData ].getName
104+
105+ val keyConverter = createCodeForArray(ctx, s " $tmp.keyArray() " , keyType)
106+ val valueConverter = createCodeForArray(ctx, s " $tmp.valueArray() " , valueType)
107+ val code = s """
108+ final MapData $tmp = $input;
109+ ${keyConverter.code}
110+ ${valueConverter.code}
111+ final MapData $output = new $mapClass( ${keyConverter.primitive}, ${valueConverter.primitive});
112+ """
113+
114+ GeneratedExpressionCode (code, " false" , output)
115+ }
116+
117+ private def convertToSafe (
118+ ctx : CodeGenContext ,
119+ input : String ,
120+ dataType : DataType ): GeneratedExpressionCode = dataType match {
121+ case s : StructType => createCodeForStruct(ctx, input, s)
122+ case ArrayType (elementType, _) => createCodeForArray(ctx, input, elementType)
123+ case MapType (keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
124+ // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
125+ case StringType => GeneratedExpressionCode (" " , " false" , s " $input.clone() " )
126+ case _ => GeneratedExpressionCode (" " , " false" , input)
67127 }
68128
69129 protected def create (expressions : Seq [Expression ]): Projection = {
@@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
72132 case (NoOp , _) => " "
73133 case (e, i) =>
74134 val evaluationCode = e.gen(ctx)
135+ val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType)
75136 evaluationCode.code +
76137 s """
77138 if ( ${evaluationCode.isNull}) {
78139 mutableRow.setNullAt( $i);
79140 } else {
80- ${genUpdater(ctx, " mutableRow" , e.dataType, i, evaluationCode.primitive)};
141+ ${converter.code}
142+ ${ctx.setColumn(" mutableRow" , e.dataType, i, converter.primitive)};
81143 }
82144 """
83145 }
0 commit comments