
Commit b7a40f6

zasdfgbnm authored and ueshin committed
[SPARK-16542][SQL][PYSPARK] Fix bugs about types that result in an array of null when creating a DataFrame using Python
## What changes were proposed in this pull request?

This is the reopen of #14198, with merge conflicts resolved. ueshin Could you please take a look at my code?

Fix bugs about types that result in an array of null when creating a DataFrame using Python.

Python's array.array has richer types than Python itself, e.g. we can have `array('f', [1, 2, 3])` and `array('d', [1, 2, 3])`. The code in spark-sql and pyspark didn't take this into consideration, which could yield an array of null values when a row contains an `array('f')`. A simple script to reproduce this bug:

```
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row, DataFrame
from array import array

sc = SparkContext()
sqlContext = SQLContext(sc)

row1 = Row(floatarray=array('f', [1, 2, 3]), doublearray=array('d', [1, 2, 3]))
rows = sc.parallelize([row1])
df = sqlContext.createDataFrame(rows)
df.show()
```

which produces the output

```
+---------------+------------------+
|    doublearray|        floatarray|
+---------------+------------------+
|[1.0, 2.0, 3.0]|[null, null, null]|
+---------------+------------------+
```

## How was this patch tested?

New test case added.

Author: Xiang Gao <[email protected]>
Author: Gao, Xiang <[email protected]>
Author: Takuya UESHIN <[email protected]>

Closes #18444 from zasdfgbnm/fix_array_infer.
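With the patch applied, the same script is expected to show the float column intact as well (expected output, inferred from the fix rather than quoted from the PR):

```
+---------------+---------------+
|    doublearray|     floatarray|
+---------------+---------------+
|[1.0, 2.0, 3.0]|[1.0, 2.0, 3.0]|
+---------------+---------------+
```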
1 parent 2c9d5ef commit b7a40f6

File tree: 4 files changed, +216 -6 lines changed


core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala

Lines changed: 16 additions & 4 deletions
```diff
@@ -55,13 +55,12 @@ private[spark] object SerDeUtil extends Logging {
   //    {'d', sizeof(double), d_getitem, d_setitem},
   //    {'\0', 0, 0, 0} /* Sentinel */
   //  };
-  // TODO: support Py_UNICODE with 2 bytes
   val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) {
-    Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9,
+    Map('B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9,
       'L' -> 11, 'l' -> 13, 'f' -> 15, 'd' -> 17, 'u' -> 21
     )
   } else {
-    Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8,
+    Map('B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8,
       'L' -> 10, 'l' -> 12, 'f' -> 14, 'd' -> 16, 'u' -> 20
     )
   }
@@ -72,7 +71,20 @@ private[spark] object SerDeUtil extends Logging {
       val typecode = args(0).asInstanceOf[String].charAt(0)
       // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly
       val data = args(1).asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1)
-      construct(typecode, machineCodes(typecode), data)
+      if (typecode == 'c') {
+        // It seems like the pickle of pypy uses the similar protocol to Python 2.6, which uses
+        // a string for array data instead of list as Python 2.7, and handles an array of
+        // typecode 'c' as 1-byte character.
+        val result = new Array[Char](data.length)
+        var i = 0
+        while (i < data.length) {
+          result(i) = data(i).toChar
+          i += 1
+        }
+        result
+      } else {
+        construct(typecode, machineCodes(typecode), data)
+      }
     } else if (args.length == 2 && args(0) == "l") {
       // On Python 2, an array of typecode 'l' should be handled as long rather than int.
       val values = args(1).asInstanceOf[JArrayList[_]]
```
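Per the comment in the patch, a PyPy pickler (like Python 2.6's) serializes array data as a string rather than a list, and a typecode-`'c'` array is a sequence of 1-byte characters. A minimal Python 2 sketch of the payload shape the Scala side has to handle (illustrative only; typecode `'c'` does not exist on Python 3):

```python
from array import array

a = array('c', 'abc')       # Python 2 only: one raw byte per character
payload = a.tostring()      # 'abc' -- the raw bytes the pickle carries
# The Scala side receives these bytes via an ISO 8859-1 string and, with this
# patch, maps each byte to a Char instead of looking up a machine code.
chars = [chr(ord(c)) for c in payload]
assert chars == ['a', 'b', 'c']
```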

python/pyspark/sql/tests.py

Lines changed: 96 additions & 1 deletion
```diff
@@ -30,8 +30,10 @@
 import functools
 import time
 import datetime
-
+import array
+import ctypes
 import py4j
+
 try:
     import xmlrunner
 except ImportError:
@@ -58,6 +60,8 @@
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
+from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
+from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
 from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -2333,6 +2337,97 @@ def test_BinaryType_serialization(self):
         df = self.spark.createDataFrame(data, schema=schema)
         df.collect()
 
+    # test for SPARK-16542
+    def test_array_types(self):
+        # This test need to make sure that the Scala type selected is at least
+        # as large as the python's types. This is necessary because python's
+        # array types depend on C implementation on the machine. Therefore there
+        # is no machine independent correspondence between python's array types
+        # and Scala types.
+        # See: https://docs.python.org/2/library/array.html
+
+        def assertCollectSuccess(typecode, value):
+            row = Row(myarray=array.array(typecode, [value]))
+            df = self.spark.createDataFrame([row])
+            self.assertEqual(df.first()["myarray"][0], value)
+
+        # supported string types
+        #
+        # String types in python's array are "u" for Py_UNICODE and "c" for char.
+        # "u" will be removed in python 4, and "c" is not supported in python 3.
+        supported_string_types = []
+        if sys.version_info[0] < 4:
+            supported_string_types += ['u']
+            # test unicode
+            assertCollectSuccess('u', u'a')
+        if sys.version_info[0] < 3:
+            supported_string_types += ['c']
+            # test string
+            assertCollectSuccess('c', 'a')
+
+        # supported float and double
+        #
+        # Test max, min, and precision for float and double, assuming IEEE 754
+        # floating-point format.
+        supported_fractional_types = ['f', 'd']
+        assertCollectSuccess('f', ctypes.c_float(1e+38).value)
+        assertCollectSuccess('f', ctypes.c_float(1e-38).value)
+        assertCollectSuccess('f', ctypes.c_float(1.123456).value)
+        assertCollectSuccess('d', sys.float_info.max)
+        assertCollectSuccess('d', sys.float_info.min)
+        assertCollectSuccess('d', sys.float_info.epsilon)
+
+        # supported signed int types
+        #
+        # The size of C types changes with implementation, we need to make sure
+        # that there is no overflow error on the platform running this test.
+        supported_signed_int_types = list(
+            set(_array_signed_int_typecode_ctype_mappings.keys())
+            .intersection(set(_array_type_mappings.keys())))
+        for t in supported_signed_int_types:
+            ctype = _array_signed_int_typecode_ctype_mappings[t]
+            max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
+            assertCollectSuccess(t, max_val - 1)
+            assertCollectSuccess(t, -max_val)
+
+        # supported unsigned int types
+        #
+        # JVM does not have unsigned types. We need to be very careful to make
+        # sure that there is no overflow error.
+        supported_unsigned_int_types = list(
+            set(_array_unsigned_int_typecode_ctype_mappings.keys())
+            .intersection(set(_array_type_mappings.keys())))
+        for t in supported_unsigned_int_types:
+            ctype = _array_unsigned_int_typecode_ctype_mappings[t]
+            assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)
+
+        # all supported types
+        #
+        # Make sure the types tested above:
+        # 1. are all supported types
+        # 2. cover all supported types
+        supported_types = (supported_string_types +
+                           supported_fractional_types +
+                           supported_signed_int_types +
+                           supported_unsigned_int_types)
+        self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
+
+        # all unsupported types
+        #
+        # Keys in _array_type_mappings is a complete list of all supported types,
+        # and types not in _array_type_mappings are considered unsupported.
+        # `array.typecodes` are not supported in python 2.
+        if sys.version_info[0] < 3:
+            all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
+        else:
+            all_types = set(array.typecodes)
+        unsupported_types = all_types - set(supported_types)
+        # test unsupported types
+        for t in unsupported_types:
+            with self.assertRaises(TypeError):
+                a = array.array(t)
+                self.spark.createDataFrame([Row(myarray=a)]).collect()
+
     def test_bucketed_write(self):
         data = [
             (1, "foo", 3.0), (2, "foo", 5.0),
```

python/pyspark/sql/types.py

Lines changed: 94 additions & 1 deletion
```diff
@@ -24,6 +24,7 @@
 import re
 import base64
 from array import array
+import ctypes
 
 if sys.version >= "3":
     long = int
@@ -915,6 +916,93 @@ def _parse_datatype_json_value(json_value):
     long: LongType,
 })
 
+# Mapping Python array types to Spark SQL DataType
+# We should be careful here. The size of these types in python depends on C
+# implementation. We need to make sure that this conversion does not lose any
+# precision. Also, JVM only support signed types, when converting unsigned types,
+# keep in mind that it requires 1 more bit when stored as signed types.
+#
+# Reference for C integer size, see:
+# ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types <limits.h>.
+# Reference for python array typecode, see:
+# https://docs.python.org/2/library/array.html
+# https://docs.python.org/3.6/library/array.html
+# Reference for JVM's supported integral types:
+# http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1
+
+_array_signed_int_typecode_ctype_mappings = {
+    'b': ctypes.c_byte,
+    'h': ctypes.c_short,
+    'i': ctypes.c_int,
+    'l': ctypes.c_long,
+}
+
+_array_unsigned_int_typecode_ctype_mappings = {
+    'B': ctypes.c_ubyte,
+    'H': ctypes.c_ushort,
+    'I': ctypes.c_uint,
+    'L': ctypes.c_ulong
+}
+
+
+def _int_size_to_type(size):
+    """
+    Return the Catalyst datatype from the size of integers.
+    """
+    if size <= 8:
+        return ByteType
+    if size <= 16:
+        return ShortType
+    if size <= 32:
+        return IntegerType
+    if size <= 64:
+        return LongType
+
+# The list of all supported array typecodes is stored here
+_array_type_mappings = {
+    # Warning: Actual properties for float and double in C is not specified in C.
+    # On almost every system supported by both python and JVM, they are IEEE 754
+    # single-precision binary floating-point format and IEEE 754 double-precision
+    # binary floating-point format. And we do assume the same thing here for now.
+    'f': FloatType,
+    'd': DoubleType
+}
+
+# compute array typecode mappings for signed integer types
+for _typecode in _array_signed_int_typecode_ctype_mappings.keys():
+    size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode]) * 8
+    dt = _int_size_to_type(size)
+    if dt is not None:
+        _array_type_mappings[_typecode] = dt
+
+# compute array typecode mappings for unsigned integer types
+for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys():
+    # JVM does not have unsigned types, so use signed types that is at least 1
+    # bit larger to store
+    size = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1
+    dt = _int_size_to_type(size)
+    if dt is not None:
+        _array_type_mappings[_typecode] = dt
+
+# Type code 'u' in Python's array is deprecated since version 3.3, and will be
+# removed in version 4.0. See: https://docs.python.org/3/library/array.html
+if sys.version_info[0] < 4:
+    _array_type_mappings['u'] = StringType
+
+# Type code 'c' are only available at python 2
+if sys.version_info[0] < 3:
+    _array_type_mappings['c'] = StringType
+
+# SPARK-21465:
+# In python2, array of 'L' happened to be mistakenly partially supported. To
+# avoid breaking user's code, we should keep this partial support. Below is a
+# dirty hacking to keep this partial support and make the unit test passes
+import platform
+if sys.version_info[0] < 3 and platform.python_implementation() != 'PyPy':
+    if 'L' not in _array_type_mappings.keys():
+        _array_type_mappings['L'] = LongType
+        _array_unsigned_int_typecode_ctype_mappings['L'] = ctypes.c_uint
+
 
 def _infer_type(obj):
     """Infer the DataType from obj
@@ -938,12 +1026,17 @@ def _infer_type(obj):
                 return MapType(_infer_type(key), _infer_type(value), True)
         else:
             return MapType(NullType(), NullType(), True)
-    elif isinstance(obj, (list, array)):
+    elif isinstance(obj, list):
         for v in obj:
            if v is not None:
                 return ArrayType(_infer_type(obj[0]), True)
         else:
             return ArrayType(NullType(), True)
+    elif isinstance(obj, array):
+        if obj.typecode in _array_type_mappings:
+            return ArrayType(_array_type_mappings[obj.typecode](), False)
+        else:
+            raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
             return _infer_schema(obj)
```
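A hedged usage sketch of the new inference path (assumes a Python 3 interpreter with this version of PySpark importable; `_infer_type` is an internal helper, so its exact repr may vary):

```python
from array import array

from pyspark.sql.types import _infer_type

# A Catalyst type wide enough for the C element type is now chosen per
# typecode, with containsNull=False.
print(_infer_type(array('f', [1, 2, 3])))   # ArrayType(FloatType,false) -- expected
print(_infer_type(array('d', [1, 2, 3])))   # ArrayType(DoubleType,false) -- expected

# Typecodes outside _array_type_mappings now fail fast instead of silently
# producing an array of nulls ('q' is absent from the mappings).
try:
    _infer_type(array('q', [1]))
except TypeError as e:
    print(e)                                # not supported type: array(q)
```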

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -91,20 +91,30 @@ object EvaluatePython {
 
     case (c: Boolean, BooleanType) => c
 
+    case (c: Byte, ByteType) => c
+    case (c: Short, ByteType) => c.toByte
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
 
+    case (c: Byte, ShortType) => c.toShort
+    case (c: Short, ShortType) => c
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
 
+    case (c: Byte, IntegerType) => c.toInt
+    case (c: Short, IntegerType) => c.toInt
     case (c: Int, IntegerType) => c
     case (c: Long, IntegerType) => c.toInt
 
+    case (c: Byte, LongType) => c.toLong
+    case (c: Short, LongType) => c.toLong
     case (c: Int, LongType) => c.toLong
     case (c: Long, LongType) => c
 
+    case (c: Float, FloatType) => c
     case (c: Double, FloatType) => c.toFloat
 
+    case (c: Float, DoubleType) => c.toDouble
     case (c: Double, DoubleType) => c
 
     case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
```
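A minimal end-to-end check of the round trip these extra cases enable (a sketch assuming a live `SparkSession` named `spark`; the column name `xs` is illustrative):

```python
from array import array

from pyspark.sql import Row

# array('b', ...) is inferred as an array of ByteType; EvaluatePython now
# matches (Byte, ByteType) and friends, so the original values come back
# from collect() instead of nulls.
df = spark.createDataFrame([Row(xs=array('b', [1, -2, 3]))])
df.printSchema()       # xs: array, element: byte, containsNull = false -- expected
print(df.first().xs)   # [1, -2, 3]
```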
