|
4 | 4 |
|
5 | 5 | from typing import Final, NamedTuple |
6 | 6 |
|
| 7 | +import mypy.checker |
7 | 8 | import mypy.plugin |
8 | | -from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var |
| 9 | +from mypy.argmap import map_actuals_to_formals |
| 10 | +from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var |
9 | 11 | from mypy.plugins.common import add_method_to_class |
10 | | -from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type |
| 12 | +from mypy.types import ( |
| 13 | + AnyType, |
| 14 | + CallableType, |
| 15 | + Instance, |
| 16 | + Overloaded, |
| 17 | + Type, |
| 18 | + TypeOfAny, |
| 19 | + UnboundType, |
| 20 | + UninhabitedType, |
| 21 | + get_proper_type, |
| 22 | +) |
11 | 23 |
|
12 | 24 | functools_total_ordering_makers: Final = {"functools.total_ordering"} |
13 | 25 |
|
@@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | |
102 | 114 | comparison_methods[name] = None |
103 | 115 |
|
104 | 116 | return comparison_methods |
| 117 | + |
| 118 | + |
| 119 | +def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: |
| 120 | + """Infer a more precise return type for functools.partial""" |
| 121 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals |
| 122 | + return ctx.default_return_type |
| 123 | + if len(ctx.arg_types) != 3: # fn, *args, **kwargs |
| 124 | + return ctx.default_return_type |
| 125 | + if len(ctx.arg_types[0]) != 1: |
| 126 | + return ctx.default_return_type |
| 127 | + |
| 128 | + if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): |
| 129 | + # TODO: handle overloads, just fall back to whatever the non-plugin code does |
| 130 | + return ctx.default_return_type |
| 131 | + fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) |
| 132 | + if fn_type is None: |
| 133 | + return ctx.default_return_type |
| 134 | + |
| 135 | + defaulted = fn_type.copy_modified( |
| 136 | + arg_kinds=[ |
| 137 | + ( |
| 138 | + ArgKind.ARG_OPT |
| 139 | + if k == ArgKind.ARG_POS |
| 140 | + else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) |
| 141 | + ) |
| 142 | + for k in fn_type.arg_kinds |
| 143 | + ] |
| 144 | + ) |
| 145 | + if defaulted.line < 0: |
| 146 | + # Make up a line number if we don't have one |
| 147 | + defaulted.set_line(ctx.default_return_type) |
| 148 | + |
| 149 | + actual_args = [a for param in ctx.args[1:] for a in param] |
| 150 | + actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param] |
| 151 | + actual_arg_names = [a for param in ctx.arg_names[1:] for a in param] |
| 152 | + actual_types = [a for param in ctx.arg_types[1:] for a in param] |
| 153 | + |
| 154 | + _, bound = ctx.api.expr_checker.check_call( |
| 155 | + callee=defaulted, |
| 156 | + args=actual_args, |
| 157 | + arg_kinds=actual_arg_kinds, |
| 158 | + arg_names=actual_arg_names, |
| 159 | + context=defaulted, |
| 160 | + ) |
| 161 | + bound = get_proper_type(bound) |
| 162 | + if not isinstance(bound, CallableType): |
| 163 | + return ctx.default_return_type |
| 164 | + |
| 165 | + formal_to_actual = map_actuals_to_formals( |
| 166 | + actual_kinds=actual_arg_kinds, |
| 167 | + actual_names=actual_arg_names, |
| 168 | + formal_kinds=fn_type.arg_kinds, |
| 169 | + formal_names=fn_type.arg_names, |
| 170 | + actual_arg_type=lambda i: actual_types[i], |
| 171 | + ) |
| 172 | + |
| 173 | + partial_kinds = [] |
| 174 | + partial_types = [] |
| 175 | + partial_names = [] |
| 176 | + # We need to fully apply any positional arguments (they cannot be respecified) |
| 177 | + # However, keyword arguments can be respecified, so just give them a default |
| 178 | + for i, actuals in enumerate(formal_to_actual): |
| 179 | + if len(bound.arg_types) == len(fn_type.arg_types): |
| 180 | + arg_type = bound.arg_types[i] |
| 181 | + if isinstance(get_proper_type(arg_type), UninhabitedType): |
| 182 | + arg_type = fn_type.arg_types[i] # bit of a hack |
| 183 | + else: |
| 184 | + # TODO: I assume that bound and fn_type have the same arguments. It appears this isn't |
| 185 | + # true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple |
| 186 | + arg_type = fn_type.arg_types[i] |
| 187 | + |
| 188 | + if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2): |
| 189 | + partial_kinds.append(fn_type.arg_kinds[i]) |
| 190 | + partial_types.append(arg_type) |
| 191 | + partial_names.append(fn_type.arg_names[i]) |
| 192 | + elif actuals: |
| 193 | + if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals): |
| 194 | + continue |
| 195 | + kind = actual_arg_kinds[actuals[0]] |
| 196 | + if kind == ArgKind.ARG_NAMED: |
| 197 | + kind = ArgKind.ARG_NAMED_OPT |
| 198 | + partial_kinds.append(kind) |
| 199 | + partial_types.append(arg_type) |
| 200 | + partial_names.append(fn_type.arg_names[i]) |
| 201 | + |
| 202 | + ret_type = bound.ret_type |
| 203 | + if isinstance(get_proper_type(ret_type), UninhabitedType): |
| 204 | + ret_type = fn_type.ret_type # same kind of hack as above |
| 205 | + |
| 206 | + partially_applied = fn_type.copy_modified( |
| 207 | + arg_types=partial_types, |
| 208 | + arg_kinds=partial_kinds, |
| 209 | + arg_names=partial_names, |
| 210 | + ret_type=ret_type, |
| 211 | + ) |
| 212 | + |
| 213 | + ret = ctx.api.named_generic_type("functools.partial", [ret_type]) |
| 214 | + ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) |
| 215 | + return ret |
| 216 | + |
| 217 | + |
| 218 | +def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: |
| 219 | + """Infer a more precise return type for functools.partial.__call__.""" |
| 220 | + if ( |
| 221 | + not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals |
| 222 | + or not isinstance(ctx.type, Instance) |
| 223 | + or ctx.type.type.fullname != "functools.partial" |
| 224 | + or not ctx.type.extra_attrs |
| 225 | + or "__mypy_partial" not in ctx.type.extra_attrs.attrs |
| 226 | + ): |
| 227 | + return ctx.default_return_type |
| 228 | + |
| 229 | + partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] |
| 230 | + if len(ctx.arg_types) != 2: # *args, **kwargs |
| 231 | + return ctx.default_return_type |
| 232 | + |
| 233 | + args = [a for param in ctx.args for a in param] |
| 234 | + arg_kinds = [a for param in ctx.arg_kinds for a in param] |
| 235 | + arg_names = [a for param in ctx.arg_names for a in param] |
| 236 | + |
| 237 | + result = ctx.api.expr_checker.check_call( |
| 238 | + callee=partial_type, |
| 239 | + args=args, |
| 240 | + arg_kinds=arg_kinds, |
| 241 | + arg_names=arg_names, |
| 242 | + context=ctx.context, |
| 243 | + ) |
| 244 | + return result[0] |
0 commit comments