File tree Expand file tree Collapse file tree 2 files changed +12
-3
lines changed Expand file tree Collapse file tree 2 files changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -1224,7 +1224,7 @@ class DataFrame(NDFrame, OpsMixin):
12241224 @overload
12251225 def apply (
12261226 self ,
1227- f : Callable [..., S1 ],
1227+ f : Callable [..., S1 | NAType ],
12281228 axis : AxisIndex = ...,
12291229 raw : _bool = ...,
12301230 result_type : None = ...,
@@ -1248,7 +1248,7 @@ class DataFrame(NDFrame, OpsMixin):
12481248 @overload
12491249 def apply (
12501250 self ,
1251- f : Callable [..., S1 ],
1251+ f : Callable [..., S1 | NAType ],
12521252 axis : Axis = ...,
12531253 raw : _bool = ...,
12541254 args : Any = ...,
@@ -1309,7 +1309,7 @@ class DataFrame(NDFrame, OpsMixin):
13091309 @overload
13101310 def apply (
13111311 self ,
1312- f : Callable [..., S1 ],
1312+ f : Callable [..., S1 | NAType ],
13131313 raw : _bool = ...,
13141314 result_type : None = ...,
13151315 args : Any = ...,
Original file line number Diff line number Diff line change 4343)
4444import xarray as xr
4545
46+ from pandas ._libs .missing import NAType
4647from pandas ._typing import Scalar
4748
4849from tests import (
@@ -578,6 +579,9 @@ def test_types_apply() -> None:
578579 def returns_scalar (x : pd .Series ) -> int :
579580 return 2
580581
582+ def returns_scalar_na (x : pd .Series ) -> int | NAType :
583+ return 2 if (x < 5 ).all () else pd .NA
584+
581585 def returns_series (x : pd .Series ) -> pd .Series :
582586 return x ** 2
583587
@@ -604,6 +608,11 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
604608 check (
605609 assert_type (df .apply (returns_scalar ), "pd.Series[int]" ), pd .Series , np .integer
606610 )
611+ check (
612+ assert_type (df .apply (returns_scalar_na ), "pd.Series[int]" ),
613+ pd .Series ,
614+ int ,
615+ )
607616 check (assert_type (df .apply (returns_series ), pd .DataFrame ), pd .DataFrame )
608617 check (assert_type (df .apply (returns_listlike_of_3 ), pd .DataFrame ), pd .DataFrame )
609618 check (assert_type (df .apply (returns_dict ), pd .Series ), pd .Series )
You can’t perform that action at this time.
0 commit comments