18 changes: 18 additions & 0 deletions python/pyspark/sql/tests.py
@@ -4381,6 +4381,24 @@ def test_timestamp_dst(self):
        result = df.withColumn('time', foo_udf(df.time))
        self.assertEquals(df.collect(), result.collect())

    @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
    def test_type_annotation(self):
        from pyspark.sql.functions import pandas_udf
        # Regression test to check if type hints can be used. See SPARK-23569.
        # Note that it throws an error during compilation in lower Python versions if 'exec'
        # is not used. Also, note that we explicitly use another dictionary to avoid modifications
        # in the current 'locals()'.
        #
        # Hyukjin: I think this is an ugly way to test syntax that only exists in higher
        # Python versions, which we shouldn't encourage. It was the last resort I could
        # come up with at the time.
        _locals = {}
        exec(
            "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
            _locals)
        df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
        self.assertEqual(df.first()[0], 0)


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,
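For context, what the test above verifies can be written directly on Python 3.5+ without the `exec` workaround. A minimal sketch, assuming an active `SparkSession` bound to `spark` and pandas/PyArrow installed:

import pandas as pd

from pyspark.sql.functions import pandas_udf

# A scalar pandas UDF declared with type hints; before this change, passing an
# annotated function to `pandas_udf` failed because `inspect.getargspec` rejects
# annotations on Python 3.
def noop(col: pd.Series) -> pd.Series:
    return col

# `spark` is assumed to be an existing SparkSession.
df = spark.range(1).select(pandas_udf(f=noop, returnType='bigint')('id'))
assert df.first()[0] == 0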
9 changes: 8 additions & 1 deletion python/pyspark/sql/udf.py
@@ -42,10 +42,17 @@ def _create_udf(f, returnType, evalType):
                    PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):

        import inspect
        import sys
        from pyspark.sql.utils import require_minimum_pyarrow_version

        require_minimum_pyarrow_version()
        argspec = inspect.getargspec(f)

        if sys.version_info[0] < 3:
            # `getargspec` is deprecated since Python 3.0 (incompatible with function annotations).
            # See SPARK-23569.
            argspec = inspect.getargspec(f)
        else:
            argspec = inspect.getfullargspec(f)

        if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
                argspec.varargs is None:
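Why the version branch above is needed: on Python 3, `inspect.getargspec` raises a `ValueError` for any function that carries annotations (and it was removed entirely in Python 3.11), whereas `inspect.getfullargspec` handles them. A standalone sketch illustrating the difference, no Spark required:

import inspect


# String annotations keep this example self-contained (no pandas import needed).
def annotated(col: 'pd.Series') -> 'pd.Series':
    return col


# getfullargspec understands annotated functions.
spec = inspect.getfullargspec(annotated)
print(spec.args)         # ['col']
print(spec.annotations)  # {'col': 'pd.Series', 'return': 'pd.Series'}

# getargspec rejects them; guard the call since it is gone on Python 3.11+.
if hasattr(inspect, 'getargspec'):
    try:
        inspect.getargspec(annotated)
    except ValueError as error:
        print(error)  # "Function has keyword-only parameters or annotations, ..."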