@@ -22,27 +22,45 @@ import java.net.Socket
 import java.nio.channels.Channels
 import java.util.Locale
 
-import net.razorvine.pickle.Pickler
+import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.security.SocketAuthServer
 import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 private[sql] object PythonSQLUtils extends Logging {
-  private lazy val internalRowPickler = {
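+  // Registers the SQL picklers, then runs `f` with a fresh Pickler that is
+  // closed once `f` completes, so no pickler state is shared across calls.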
+  private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = {
     EvaluatePython.registerPicklers()
-    new Pickler(true, false)
+    val pickler = new Pickler(true, false)
+    val ret = try {
+      f(pickler)
+    } finally {
+      pickler.close()
+    }
+    ret
+  }
+
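+  // Same pattern for the reverse direction: a fresh Unpickler per call,
+  // closed after `f` completes.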
+  private def withInternalRowUnpickler(f: Unpickler => Any): Any = {
+    EvaluatePython.registerPicklers()
+    val unpickler = new Unpickler
+    val ret = try {
+      f(unpickler)
+    } finally {
+      unpickler.close()
+    }
+    ret
   }
 
   def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
@@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging {
 
   def toPyRow(row: Row): Array[Byte] = {
     assert(row.isInstanceOf[GenericRowWithSchema])
-    internalRowPickler.dumps(EvaluatePython.toJava(
-      CatalystTypeConverters.convertToCatalyst(row), row.schema))
+    withInternalRowPickler(_.dumps(EvaluatePython.toJava(
+      CatalystTypeConverters.convertToCatalyst(row), row.schema)))
+  }
+
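+  // Converts a pickled Python row back into a JVM Row: unpickle the bytes,
+  // coerce the result to Catalyst values for `returnType`, then run the
+  // encoder's deserializer.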
+  def toJVMRow(
+      arr: Array[Byte],
+      returnType: StructType,
+      deserializer: ExpressionEncoder.Deserializer[Row]): Row = {
+    val fromJava = EvaluatePython.makeFromJava(returnType)
+    val internalRow =
+      fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow]
+    deserializer(internalRow)
   }
 
   def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr))
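
For context, a minimal sketch of how the new toJVMRow might be driven from the JVM side. The helper below is hypothetical and not part of this patch; it assumes the standard Spark 3.x encoder APIs (RowEncoder, resolveAndBind, createDeserializer) for building the deserializer argument, and that the bytes were pickled on the Python side.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types.StructType

    // Hypothetical caller; it must live under org.apache.spark.sql since
    // PythonSQLUtils is private[sql].
    def fromPython(bytes: Array[Byte], schema: StructType): Row = {
      // Build the ExpressionEncoder.Deserializer[Row] that toJVMRow expects.
      val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer()
      PythonSQLUtils.toJVMRow(bytes, schema, deserializer)
    }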