11import threading
2- from typing import Any , Callable , Generic , Optional , TypeVar
2+ from typing import Any , Callable , Optional , TypeVar
3+
4+ from typing_extensions import Concatenate , Generic , ParamSpec
35
46from basilisp .lang import map as lmap
57from basilisp .lang import runtime
810from basilisp .lang .set import PersistentSet
911
1012T = TypeVar ("T" )
11- DispatchFunction = Callable [..., T ]
12- Method = Callable [..., Any ]
13+ P = ParamSpec ("P" )
14+ DispatchFunction = Callable [Concatenate [T , P ], T ]
15+ Method = Callable [Concatenate [T , P ], Any ]
1316
1417
1518_GLOBAL_HIERARCHY_SYM = sym .symbol ("global-hierarchy" , ns = runtime .CORE_NS )
1619_ISA_SYM = sym .symbol ("isa?" , ns = runtime .CORE_NS )
1720
1821
19- class MultiFunction (Generic [T ]):
22+ class MultiFunction (Generic [T , P ]):
2023 __slots__ = (
2124 "_name" ,
2225 "_default" ,
@@ -33,7 +36,7 @@ class MultiFunction(Generic[T]):
3336 def __init__ (
3437 self ,
3538 name : sym .Symbol ,
36- dispatch : DispatchFunction ,
39+ dispatch : DispatchFunction [ T , P ] ,
3740 default : T ,
3841 hierarchy : Optional [IRef ] = None ,
3942 ) -> None :
@@ -63,11 +66,11 @@ def __init__(
6366 # caches.
6467 self ._cached_hierarchy = self ._hierarchy .deref ()
6568
66- def __call__ (self , * args , ** kwargs ) :
67- key = self ._dispatch (* args , ** kwargs )
69+ def __call__ (self , v : T , * args : P . args , ** kwargs : P . kwargs ) -> Any :
70+ key = self ._dispatch (v , * args , ** kwargs )
6871 method = self .get_method (key )
6972 if method is not None :
70- return method (* args , ** kwargs )
73+ return method (v , * args , ** kwargs )
7174 raise NotImplementedError
7275
7376 def _reset_cache (self ):
@@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool:
9497 selection."""
9598 return self ._has_preference (tag , parent ) or self ._is_a (tag , parent )
9699
97- def add_method (self , key : T , method : Method ) -> None :
100+ def add_method (self , key : T , method : Method [ T , P ] ) -> None :
98101 """Add a new method to this function which will respond for key returned from
99102 the dispatch function."""
100103 with self ._lock :
101104 self ._methods = self ._methods .assoc (key , method )
102105 self ._reset_cache ()
103106
104- def _find_and_cache_method (self , key : T ) -> Optional [Method ]:
107+ def _find_and_cache_method (self , key : T ) -> Optional [Method [ T , P ] ]:
105108 """Find and cache the best method for dispatch value `key`."""
106109 with self ._lock :
107110 best_key : Optional [T ] = None
@@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]:
125128
126129 return best_method
127130
128- def get_method (self , key : T ) -> Optional [Method ]:
131+ def get_method (self , key : T ) -> Optional [Method [ T , P ] ]:
129132 """Return the method which would handle this dispatch key or None if no method
130133 defined for this key and no default."""
131134 if self ._cached_hierarchy != self ._hierarchy .deref ():
@@ -159,7 +162,7 @@ def prefers(self):
159162 """Return a mapping of preferred values to the set of other values."""
160163 return self ._prefers
161164
162- def remove_method (self , key : T ) -> Optional [Method ]:
165+ def remove_method (self , key : T ) -> Optional [Method [ T , P ] ]:
163166 """Remove the method defined for this key and return it."""
164167 with self ._lock :
165168 method = self ._methods .val_at (key , None )
@@ -179,5 +182,5 @@ def default(self) -> T:
179182 return self ._default
180183
181184 @property
182- def methods (self ) -> IPersistentMap [T , Method ]:
185+ def methods (self ) -> IPersistentMap [T , Method [ T , P ] ]:
183186 return self ._methods
0 commit comments