66from mypy .plugin import (
77 Plugin , FunctionContext , MethodContext , MethodSigContext , AttributeContext , ClassDefContext
88)
9+ from mypy .plugins .common import try_getting_str_literal
910from mypy .types import (
1011 Type , Instance , AnyType , TypeOfAny , CallableType , NoneTyp , UnionType , TypedDictType ,
1112 TypeVarType
@@ -170,24 +171,26 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
170171 if (isinstance (ctx .type , TypedDictType )
171172 and len (ctx .arg_types ) >= 1
172173 and len (ctx .arg_types [0 ]) == 1 ):
173- if isinstance (ctx .args [0 ][0 ], StrExpr ):
174- key = ctx .args [0 ][0 ].value
175- value_type = ctx .type .items .get (key )
176- if value_type :
177- if len (ctx .arg_types ) == 1 :
178- return UnionType .make_simplified_union ([value_type , NoneTyp ()])
179- elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
180- and len (ctx .args [1 ]) == 1 ):
181- default_arg = ctx .args [1 ][0 ]
182- if (isinstance (default_arg , DictExpr ) and len (default_arg .items ) == 0
183- and isinstance (value_type , TypedDictType )):
184- # Special case '{}' as the default for a typed dict type.
185- return value_type .copy_modified (required_keys = set ())
186- else :
187- return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
188- else :
189- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
190- return AnyType (TypeOfAny .from_error )
174+ key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
175+ if key is None :
176+ return ctx .default_return_type
177+
178+ value_type = ctx .type .items .get (key )
179+ if value_type :
180+ if len (ctx .arg_types ) == 1 :
181+ return UnionType .make_simplified_union ([value_type , NoneTyp ()])
182+ elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
183+ and len (ctx .args [1 ]) == 1 ):
184+ default_arg = ctx .args [1 ][0 ]
185+ if (isinstance (default_arg , DictExpr ) and len (default_arg .items ) == 0
186+ and isinstance (value_type , TypedDictType )):
187+ # Special case '{}' as the default for a typed dict type.
188+ return value_type .copy_modified (required_keys = set ())
189+ else :
190+ return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
191+ else :
192+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
193+ return AnyType (TypeOfAny .from_error )
191194 return ctx .default_return_type
192195
193196
@@ -225,23 +228,23 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type:
225228 if (isinstance (ctx .type , TypedDictType )
226229 and len (ctx .arg_types ) >= 1
227230 and len (ctx .arg_types [0 ]) == 1 ):
228- if isinstance (ctx .args [0 ][0 ], StrExpr ):
229- key = ctx .args [0 ][0 ].value
230- if key in ctx .type .required_keys :
231- ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
232- value_type = ctx .type .items .get (key )
233- if value_type :
234- if len (ctx .args [1 ]) == 0 :
235- return value_type
236- elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
237- and len (ctx .args [1 ]) == 1 ):
238- return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
239- else :
240- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
241- return AnyType (TypeOfAny .from_error )
242- else :
231+ key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
232+ if key is None :
243233 ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
244234 return AnyType (TypeOfAny .from_error )
235+
236+ if key in ctx .type .required_keys :
237+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
238+ value_type = ctx .type .items .get (key )
239+ if value_type :
240+ if len (ctx .args [1 ]) == 0 :
241+ return value_type
242+ elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
243+ and len (ctx .args [1 ]) == 1 ):
244+ return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
245+ else :
246+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
247+ return AnyType (TypeOfAny .from_error )
245248 return ctx .default_return_type
246249
247250
@@ -271,17 +274,17 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
271274 if (isinstance (ctx .type , TypedDictType )
272275 and len (ctx .arg_types ) == 2
273276 and len (ctx .arg_types [0 ]) == 1 ):
274- if isinstance (ctx .args [0 ][0 ], StrExpr ):
275- key = ctx .args [0 ][0 ].value
276- value_type = ctx .type .items .get (key )
277- if value_type :
278- return value_type
279- else :
280- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
281- return AnyType (TypeOfAny .from_error )
282- else :
277+ key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
278+ if key is None :
283279 ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
284280 return AnyType (TypeOfAny .from_error )
281+
282+ value_type = ctx .type .items .get (key )
283+ if value_type :
284+ return value_type
285+ else :
286+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
287+ return AnyType (TypeOfAny .from_error )
285288 return ctx .default_return_type
286289
287290
@@ -296,15 +299,15 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
296299 if (isinstance (ctx .type , TypedDictType )
297300 and len (ctx .arg_types ) == 1
298301 and len (ctx .arg_types [0 ]) == 1 ):
299- if isinstance (ctx .args [0 ][0 ], StrExpr ):
300- key = ctx .args [0 ][0 ].value
301- if key in ctx .type .required_keys :
302- ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
303- elif key not in ctx .type .items :
304- ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
305- else :
302+ key = try_getting_str_literal (ctx .args [0 ][0 ], ctx .arg_types [0 ][0 ])
303+ if key is None :
306304 ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
307305 return AnyType (TypeOfAny .from_error )
306+
307+ if key in ctx .type .required_keys :
308+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
309+ elif key not in ctx .type .items :
310+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
308311 return ctx .default_return_type
309312
310313
0 commit comments