Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,14 @@ def dispatch(cls):
dispatch_cache[cls] = impl
return impl

def _is_union_type(cls):
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}

def _is_valid_union_type(cls):
from typing import get_args
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))

def register(cls, func=None):
"""generic_func.register(cls, func) -> func

Expand All @@ -845,7 +853,7 @@ def register(cls, func=None):
"""
nonlocal cache_token
if func is None:
if isinstance(cls, type):
if isinstance(cls, type) or _is_valid_union_type(cls):
return lambda f: register(cls, f)
ann = getattr(cls, '__annotations__', {})
if not ann:
Expand All @@ -859,12 +867,25 @@ def register(cls, func=None):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
if not isinstance(cls, type):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} is not a class."
)
registry[cls] = func
if not isinstance(cls, type) and not _is_valid_union_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} not all arguments are classes."
)
else:
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} is not a class."
)

if _is_union_type(cls):
from typing import get_args

for arg in get_args(cls):
registry[arg] = func
else:
registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
Expand Down
30 changes: 30 additions & 0 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2684,6 +2684,17 @@ def _(arg: typing.Iterable[str]):
'typing.Iterable[str] is not a class.'
))

with self.assertRaises(TypeError) as exc:
@i.register
def _(arg: typing.Union[int, typing.Iterable[str]]):
return "Invalid Union"
self.assertTrue(str(exc.exception).startswith(
"Invalid annotation for 'arg'."
))
self.assertTrue(str(exc.exception).endswith(
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
))

def test_invalid_positional_argument(self):
@functools.singledispatch
def f(*args):
Expand All @@ -2692,6 +2703,25 @@ def f(*args):
with self.assertRaisesRegex(TypeError, msg):
f()

def test_union(self):
@functools.singledispatch
def f(arg):
return "default"

@f.register
def _(arg: typing.Union[str, bytes]):
return "typing.Union"

@f.register
def _(arg: int | float):
return "types.UnionType"

self.assertEqual(f([]), "default")
self.assertEqual(f(""), "typing.Union")
self.assertEqual(f(b""), "typing.Union")
self.assertEqual(f(1), "types.UnionType")
self.assertEqual(f(1.0), "types.UnionType")


class CachedCostItem:
_cost = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add ability to use ``typing.Union`` and ``types.UnionType`` as dispatch
argument to ``functools.singledispatch``. Patch provided by Yurii Karabas.