|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from typing import Iterable, List, Optional, cast |
| 5 | +from typing import Iterable, List, cast |
6 | 6 | from typing_extensions import Final, Literal
|
7 | 7 |
|
8 | 8 | import mypy.plugin # To avoid circular imports.
|
|
43 | 43 | Var,
|
44 | 44 | is_class_var,
|
45 | 45 | )
|
46 |
| -from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface |
| 46 | +from mypy.plugin import SemanticAnalyzerPluginInterface |
47 | 47 | from mypy.plugins.common import (
|
48 | 48 | _get_argument,
|
49 | 49 | _get_bool_argument,
|
@@ -990,27 +990,42 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
|
990 | 990 | )
|
991 | 991 |
|
992 | 992 |
|
993 |
| -def _get_cls_from_init(t: Type) -> Optional[TypeInfo]: |
994 |
| - proper_type = get_proper_type(t) |
995 |
| - if isinstance(proper_type, CallableType): |
996 |
| - return proper_type.type_object() |
997 |
| - return None |
| 993 | +def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType: |
| 994 | + """Provide the proper signature for `attrs.fields`.""" |
| 995 | + if ctx.args and len(ctx.args) == 1 and ctx.args[0] and ctx.args[0][0]: |
998 | 996 |
|
| 997 | + # <hack> |
| 998 | + assert isinstance(ctx.api, TypeChecker) |
| 999 | + inst_type = ctx.api.expr_checker.accept(ctx.args[0][0]) |
| 1000 | + # </hack> |
| 1001 | + proper_type = get_proper_type(inst_type) |
| 1002 | + |
| 1003 | + if isinstance(proper_type, AnyType): # fields(Any) -> Any |
| 1004 | + return ctx.default_signature |
| 1005 | + |
| 1006 | + cls = None |
| 1007 | + arg_types = ctx.default_signature.arg_types |
| 1008 | + |
| 1009 | + if isinstance(proper_type, TypeVarType): |
| 1010 | + inner = get_proper_type(proper_type.upper_bound) |
| 1011 | + if isinstance(inner, Instance): |
| 1012 | + # We need to work arg_types to compensate for the attrs stubs. |
| 1013 | + arg_types = [inst_type] |
| 1014 | + cls = inner.type |
| 1015 | + elif isinstance(proper_type, CallableType): |
| 1016 | + cls = proper_type.type_object() |
999 | 1017 |
|
1000 |
| -def fields_function_callback(ctx: FunctionContext) -> Type: |
1001 |
| - """Provide the proper return value for `attrs.fields`.""" |
1002 |
| - if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]: |
1003 |
| - first_arg_type = ctx.arg_types[0][0] |
1004 |
| - cls = _get_cls_from_init(first_arg_type) |
1005 | 1018 | if cls is not None:
|
1006 | 1019 | if MAGIC_ATTR_NAME in cls.names:
|
1007 | 1020 | # This is a proper attrs class.
|
1008 | 1021 | ret_type = cls.names[MAGIC_ATTR_NAME].type
|
1009 | 1022 | if ret_type is not None:
|
1010 |
| - return ret_type |
1011 |
| - else: |
1012 |
| - ctx.api.fail( |
1013 |
| - f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class', |
1014 |
| - ctx.context, |
1015 |
| - ) |
1016 |
| - return ctx.default_return_type |
| 1023 | + return ctx.default_signature.copy_modified( |
| 1024 | + arg_types=arg_types, ret_type=ret_type |
| 1025 | + ) |
| 1026 | + |
| 1027 | + ctx.api.fail( |
| 1028 | + f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type)}"; expected an attrs class', |
| 1029 | + ctx.context, |
| 1030 | + ) |
| 1031 | + return ctx.default_signature |
0 commit comments