56 changes: 56 additions & 0 deletions python/pyspark/sql/tests.py
@@ -30,6 +30,8 @@
import functools
import time
import datetime
import array
import math

import py4j
try:
@@ -1735,6 +1737,60 @@ def test_BinaryType_serialization(self):
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    # test for SPARK-16542
    def test_array_types(self):
        int_types = set(['b', 'h', 'i', 'l'])
        float_types = set(['f', 'd'])
        unsupported_types = set(array.typecodes) - int_types - float_types

        def collected(a):
            row = Row(myarray=a)
            rdd = self.sc.parallelize([row])
            df = self.spark.createDataFrame(rdd)
            return df.collect()[0]["myarray"][0]
        # test whether pyspark can correctly handle int types
        for t in int_types:
            # test positive numbers
            a = array.array(t, [1])
            while True:
                try:
                    self.assertEqual(collected(a), a[0])
                    a[0] *= 2
                except OverflowError:
                    break
            # test negative numbers
            a = array.array(t, [-1])
            while True:
                try:
                    self.assertEqual(collected(a), a[0])
                    a[0] *= 2
                except OverflowError:
                    break
        # test whether pyspark can correctly handle float types
        for t in float_types:
            # test upper bound and precision
            a = array.array(t, [1.0])
            while not math.isinf(a[0]):
                self.assertEqual(collected(a), a[0])
                a[0] *= 2
                a[0] += 1
            # test lower bound
            a = array.array(t, [1.0])
            while a[0] != 0:
                self.assertEqual(collected(a), a[0])
                a[0] /= 2
        # test whether pyspark can correctly handle unsupported types
        for t in unsupported_types:
            try:
                a = array.array(t)
                c = collected(a)
                self.assertTrue(False)  # if no exception thrown, fail the test
            except TypeError:
                pass  # catch the expected exception and do nothing
            except:
                # if incorrect exception thrown, fail the test
                self.assertTrue(False)
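The integer loops above lean on a property of the array module itself: assigning a value that no longer fits the typecode's C storage raises OverflowError, so the last value that round-trips through collected() sits right at the type's boundary. A minimal standalone sketch of that probing pattern, outside Spark and purely for illustration:

import array

# Find the largest power of two a signed short ('h') array can hold.
a = array.array('h', [1])
while True:
    try:
        a[0] *= 2  # raises OverflowError once the value would exceed the C short range
    except OverflowError:
        break
print(a[0])  # 16384 == 2**14; one more doubling would need 32768 > SHRT_MAX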


class HiveSparkSubmitTests(SparkSubmitTests):

17 changes: 16 additions & 1 deletion python/pyspark/sql/types.py
@@ -929,6 +929,16 @@ def _parse_datatype_json_value(json_value):
    datetime.time: TimestampType,
}

# Mapping Python array types to Spark SQL DataType
_array_type_mappings = {
    'b': ByteType,
    'h': ShortType,
    'i': IntegerType,
    'l': LongType,
    'f': FloatType,
    'd': DoubleType
}

if sys.version < "3":
    _type_mappings.update({
        unicode: StringType,
@@ -958,12 +968,17 @@ def _infer_type(obj):
                return MapType(_infer_type(key), _infer_type(value), True)
        else:
            return MapType(NullType(), NullType(), True)
    elif isinstance(obj, (list, array)):
    elif isinstance(obj, list):
        for v in obj:
            if v is not None:
                return ArrayType(_infer_type(obj[0]), True)
        else:
            return ArrayType(NullType(), True)
    elif isinstance(obj, array):
        if obj.typecode in _array_type_mappings:
            return ArrayType(_array_type_mappings[obj.typecode](), True)
        else:
            raise TypeError("not supported type: array(%s)" % obj.typecode)
    else:
        try:
            return _infer_schema(obj)
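Taken together, the new mapping and the array branch in _infer_type mean that schema inference should resolve a supported typecode to a fixed Spark SQL element type and reject anything else up front. A rough usage sketch, assuming an active SparkSession bound to the name spark (expected outputs are shown as comments):

import array
from pyspark.sql import Row

df = spark.createDataFrame([Row(myarray=array.array('i', [1, 2, 3]))])
df.printSchema()
# root
#  |-- myarray: array (nullable = true)
#  |    |-- element: integer (containsNull = true)

# An unmapped typecode (e.g. a unicode array) should fail at inference time
# with the TypeError raised above rather than being silently mis-converted.
try:
    spark.createDataFrame([Row(myarray=array.array('u', u'ab'))])
except TypeError as e:
    print(e)  # not supported type: array(u)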
@@ -91,20 +91,30 @@ object EvaluatePython {

    case (c: Boolean, BooleanType) => c

    case (c: Byte, ByteType) => c
    case (c: Short, ByteType) => c.toByte
    case (c: Int, ByteType) => c.toByte
    case (c: Long, ByteType) => c.toByte

    case (c: Byte, ShortType) => c.toShort
    case (c: Short, ShortType) => c
    case (c: Int, ShortType) => c.toShort
    case (c: Long, ShortType) => c.toShort

    case (c: Byte, IntegerType) => c.toInt
    case (c: Short, IntegerType) => c.toInt
    case (c: Int, IntegerType) => c
    case (c: Long, IntegerType) => c.toInt

    case (c: Byte, LongType) => c.toLong
    case (c: Short, LongType) => c.toLong
    case (c: Int, LongType) => c.toLong
    case (c: Long, LongType) => c

    case (c: Float, FloatType) => c
    case (c: Double, FloatType) => c.toFloat

    case (c: Float, DoubleType) => c.toDouble
    case (c: Double, DoubleType) => c

    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
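The extra numeric cases matter because the element values handed to the JVM by the pickler are not guaranteed to use the exact width the schema declares: an element of array.array('b', ...) is inferred as ByteType on the Python side, but the unpickled value may arrive as a wider JVM integer, so the match converts each value to the declared Catalyst type instead of falling through. From the Python API the visible effect is simply that the small integral typecodes round-trip; a sketch, again assuming an active SparkSession named spark:

import array
from pyspark.sql import Row

# 'b' maps to ByteType; the values should come back unchanged even though
# they may cross the Python/JVM boundary as wider integers.
df = spark.createDataFrame([Row(xs=array.array('b', [1, -2, 127]))])
print(df.schema)           # expected element type: ByteType
print(df.collect()[0].xs)  # expected: [1, -2, 127]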