@@ -246,7 +246,10 @@ def visit_type_var(self, t: TypeVarType) -> Type:
246246 return repl
247247
248248 def visit_param_spec (self , t : ParamSpecType ) -> Type :
249- repl = get_proper_type (self .variables .get (t .id , t ))
249+ # set prefix to something empty so we don't duplicate it
250+ repl = get_proper_type (
251+ self .variables .get (t .id , t .copy_modified (prefix = Parameters ([], [], [])))
252+ )
250253 if isinstance (repl , Instance ):
251254 # TODO: what does prefix mean in this case?
252255 # TODO: why does this case even happen? Instances aren't plural.
@@ -369,7 +372,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
369372 # must expand both of them with all the argument types,
370373 # kinds and names in the replacement. The return type in
371374 # the replacement is ignored.
372- if isinstance (repl , CallableType ) or isinstance ( repl , Parameters ):
375+ if isinstance (repl , ( CallableType , Parameters ) ):
373376 # Substitute *args: P.args, **kwargs: P.kwargs
374377 prefix = param_spec .prefix
375378 # we need to expand the types in the prefix, so might as well
@@ -382,6 +385,23 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
382385 ret_type = t .ret_type .accept (self ),
383386 type_guard = (t .type_guard .accept (self ) if t .type_guard is not None else None ),
384387 )
388+ # TODO: Conceptually, the "len(t.arg_types) == 2" should not be here. However, this
389+ # errors without it. Either figure out how to eliminate this or place an
390+ # explanation for why this is necessary.
391+ elif isinstance (repl , ParamSpecType ) and len (t .arg_types ) == 2 :
392+ # We're substituting one paramspec for another; this can mean that the prefix
393+ # changes. (e.g. sub Concatenate[int, P] for Q)
394+ prefix = repl .prefix
395+ old_prefix = param_spec .prefix
396+
397+ # Check assumptions. I'm not sure what order to place new prefix vs old prefix:
398+ assert not old_prefix .arg_types or not prefix .arg_types
399+
400+ t = t .copy_modified (
401+ arg_types = prefix .arg_types + old_prefix .arg_types + t .arg_types ,
402+ arg_kinds = prefix .arg_kinds + old_prefix .arg_kinds + t .arg_kinds ,
403+ arg_names = prefix .arg_names + old_prefix .arg_names + t .arg_names ,
404+ )
385405
386406 var_arg = t .var_arg ()
387407 if var_arg is not None and isinstance (var_arg .typ , UnpackType ):
0 commit comments