1818 AssignmentStmt ,
1919 CallExpr ,
2020 Context ,
21+ DataclassTransformSpec ,
2122 Expression ,
2223 JsonDict ,
2324 NameExpr ,
25+ Node ,
2426 PlaceholderNode ,
2527 RefExpr ,
2628 SymbolTableNode ,
3739 add_method ,
3840 deserialize_and_fixup_type ,
3941)
42+ from mypy .semanal_shared import find_dataclass_transform_spec
4043from mypy .server .trigger import make_wildcard_trigger
4144from mypy .state import state
4245from mypy .typeops import map_type_from_supertype
5659
5760# The set of decorators that generate dataclasses.
5861dataclass_makers : Final = {"dataclass" , "dataclasses.dataclass" }
59- # The set of functions that generate dataclass fields.
60- field_makers : Final = {"dataclasses.field" }
6162
6263
6364SELF_TVAR_NAME : Final = "_DT"
65+ _TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec (
66+ eq_default = True ,
67+ order_default = False ,
68+ kw_only_default = False ,
69+ frozen_default = False ,
70+ field_specifiers = ("dataclasses.Field" , "dataclasses.field" ),
71+ )
6472
6573
6674class DataclassAttribute :
@@ -155,6 +163,7 @@ class DataclassTransformer:
155163
156164 def __init__ (self , ctx : ClassDefContext ) -> None :
157165 self ._ctx = ctx
166+ self ._spec = _get_transform_spec (ctx .reason )
158167
159168 def transform (self ) -> bool :
160169 """Apply all the necessary transformations to the underlying
@@ -172,9 +181,9 @@ def transform(self) -> bool:
172181 return False
173182 decorator_arguments = {
174183 "init" : _get_decorator_bool_argument (self ._ctx , "init" , True ),
175- "eq" : _get_decorator_bool_argument (self ._ctx , "eq" , True ),
176- "order" : _get_decorator_bool_argument (self ._ctx , "order" , False ),
177- "frozen" : _get_decorator_bool_argument (self ._ctx , "frozen" , False ),
184+ "eq" : _get_decorator_bool_argument (self ._ctx , "eq" , self . _spec . eq_default ),
185+ "order" : _get_decorator_bool_argument (self ._ctx , "order" , self . _spec . order_default ),
186+ "frozen" : _get_decorator_bool_argument (self ._ctx , "frozen" , self . _spec . frozen_default ),
178187 "slots" : _get_decorator_bool_argument (self ._ctx , "slots" , False ),
179188 "match_args" : _get_decorator_bool_argument (self ._ctx , "match_args" , True ),
180189 }
@@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
411420
412421 # Second, collect attributes belonging to the current class.
413422 current_attr_names : set [str ] = set ()
414- kw_only = _get_decorator_bool_argument (ctx , "kw_only" , False )
423+ kw_only = _get_decorator_bool_argument (ctx , "kw_only" , self . _spec . kw_only_default )
415424 for stmt in cls .defs .body :
416425 # Any assignment that doesn't use the new type declaration
417426 # syntax can be ignored out of hand.
@@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
461470 if self ._is_kw_only_type (node_type ):
462471 kw_only = True
463472
464- has_field_call , field_args = _collect_field_args (stmt .rvalue , ctx )
473+ has_field_call , field_args = self . _collect_field_args (stmt .rvalue , ctx )
465474
466475 is_in_init_param = field_args .get ("init" )
467476 if is_in_init_param is None :
@@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
614623 kind = MDEF , node = var , plugin_generated = True
615624 )
616625
626+ def _collect_field_args (
627+ self , expr : Expression , ctx : ClassDefContext
628+ ) -> tuple [bool , dict [str , Expression ]]:
629+ """Returns a tuple where the first value represents whether or not
630+ the expression is a call to dataclass.field and the second is a
631+ dictionary of the keyword arguments that field() was called with.
632+ """
633+ if (
634+ isinstance (expr , CallExpr )
635+ and isinstance (expr .callee , RefExpr )
636+ and expr .callee .fullname in self ._spec .field_specifiers
637+ ):
638+ # field() only takes keyword arguments.
639+ args = {}
640+ for name , arg , kind in zip (expr .arg_names , expr .args , expr .arg_kinds ):
641+ if not kind .is_named ():
642+ if kind .is_named (star = True ):
643+ # This means that `field` is used with `**` unpacking,
644+ # the best we can do for now is not to fail.
645+ # TODO: we can infer what's inside `**` and try to collect it.
646+ message = 'Unpacking **kwargs in "field()" is not supported'
647+ else :
648+ message = '"field()" does not accept positional arguments'
649+ ctx .api .fail (message , expr )
650+ return True , {}
651+ assert name is not None
652+ args [name ] = arg
653+ return True , args
654+ return False , {}
655+
617656
618657def dataclass_tag_callback (ctx : ClassDefContext ) -> None :
619658 """Record that we have a dataclass in the main semantic analysis pass.
@@ -631,32 +670,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
631670 return transformer .transform ()
632671
633672
634- def _collect_field_args (
635- expr : Expression , ctx : ClassDefContext
636- ) -> tuple [bool , dict [str , Expression ]]:
637- """Returns a tuple where the first value represents whether or not
638- the expression is a call to dataclass.field and the second is a
639- dictionary of the keyword arguments that field() was called with.
673+ def _get_transform_spec (reason : Expression ) -> DataclassTransformSpec :
674+ """Find the relevant transform parameters from the decorator/parent class/metaclass that
675+ triggered the dataclasses plugin.
676+
677+ Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
678+ function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
679+ In those cases, we return a default spec rather than one based on a call to
680+ `typing.dataclass_transform`.
640681 """
641- if (
642- isinstance (expr , CallExpr )
643- and isinstance (expr .callee , RefExpr )
644- and expr .callee .fullname in field_makers
645- ):
646- # field() only takes keyword arguments.
647- args = {}
648- for name , arg , kind in zip (expr .arg_names , expr .args , expr .arg_kinds ):
649- if not kind .is_named ():
650- if kind .is_named (star = True ):
651- # This means that `field` is used with `**` unpacking,
652- # the best we can do for now is not to fail.
653- # TODO: we can infer what's inside `**` and try to collect it.
654- message = 'Unpacking **kwargs in "field()" is not supported'
655- else :
656- message = '"field()" does not accept positional arguments'
657- ctx .api .fail (message , expr )
658- return True , {}
659- assert name is not None
660- args [name ] = arg
661- return True , args
662- return False , {}
682+ if _is_dataclasses_decorator (reason ):
683+ return _TRANSFORM_SPEC_FOR_DATACLASSES
684+
685+ spec = find_dataclass_transform_spec (reason )
686+ assert spec is not None , (
687+ "trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
688+ "decorated with typing.dataclass_transform"
689+ )
690+ return spec
691+
692+
693+ def _is_dataclasses_decorator (node : Node ) -> bool :
694+ if isinstance (node , CallExpr ):
695+ node = node .callee
696+ if isinstance (node , RefExpr ):
697+ return node .fullname in dataclass_makers
698+ return False
0 commit comments