@@ -1634,10 +1634,10 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
16341634
16351635 Parameters
16361636 ----------
1637- arr : ndarray
1637+ arr : ndarray or ExtensionArray
16381638 n : int
16391639 number of periods
1640- axis : int
1640+ axis : {0, 1}
16411641 axis to shift on
16421642 stacklevel : int
16431643 The stacklevel for the lost dtype warning.
@@ -1651,7 +1651,8 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
16511651 na = np .nan
16521652 dtype = arr .dtype
16531653
1654- if dtype .kind == "b" :
1654+ is_bool = is_bool_dtype (dtype )
1655+ if is_bool :
16551656 op = operator .xor
16561657 else :
16571658 op = operator .sub
@@ -1677,17 +1678,15 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
16771678 dtype = arr .dtype
16781679
16791680 is_timedelta = False
1680- is_bool = False
16811681 if needs_i8_conversion (arr .dtype ):
16821682 dtype = np .int64
16831683 arr = arr .view ("i8" )
16841684 na = iNaT
16851685 is_timedelta = True
16861686
1687- elif is_bool_dtype ( dtype ) :
1687+ elif is_bool :
16881688 # We have to cast in order to be able to hold np.nan
16891689 dtype = np .object_
1690- is_bool = True
16911690
16921691 elif is_integer_dtype (dtype ):
16931692 # We have to cast in order to be able to hold np.nan
@@ -1708,45 +1707,26 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
17081707 dtype = np .dtype (dtype )
17091708 out_arr = np .empty (arr .shape , dtype = dtype )
17101709
1711- na_indexer = [slice (None )] * arr . ndim
1710+ na_indexer = [slice (None )] * 2
17121711 na_indexer [axis ] = slice (None , n ) if n >= 0 else slice (n , None )
17131712 out_arr [tuple (na_indexer )] = na
17141713
1715- if arr .ndim == 2 and arr . dtype .name in _diff_special :
1714+ if arr .dtype .name in _diff_special :
17161715 # TODO: can diff_2d dtype specialization troubles be fixed by defining
17171716 # out_arr inside diff_2d?
17181717 algos .diff_2d (arr , out_arr , n , axis , datetimelike = is_timedelta )
17191718 else :
17201719 # To keep mypy happy, _res_indexer is a list while res_indexer is
17211720 # a tuple, ditto for lag_indexer.
1722- _res_indexer = [slice (None )] * arr . ndim
1721+ _res_indexer = [slice (None )] * 2
17231722 _res_indexer [axis ] = slice (n , None ) if n >= 0 else slice (None , n )
17241723 res_indexer = tuple (_res_indexer )
17251724
1726- _lag_indexer = [slice (None )] * arr . ndim
1725+ _lag_indexer = [slice (None )] * 2
17271726 _lag_indexer [axis ] = slice (None , - n ) if n > 0 else slice (- n , None )
17281727 lag_indexer = tuple (_lag_indexer )
17291728
1730- # need to make sure that we account for na for datelike/timedelta
1731- # we don't actually want to subtract these i8 numbers
1732- if is_timedelta :
1733- res = arr [res_indexer ]
1734- lag = arr [lag_indexer ]
1735-
1736- mask = (arr [res_indexer ] == na ) | (arr [lag_indexer ] == na )
1737- if mask .any ():
1738- res = res .copy ()
1739- res [mask ] = 0
1740- lag = lag .copy ()
1741- lag [mask ] = 0
1742-
1743- result = res - lag
1744- result [mask ] = na
1745- out_arr [res_indexer ] = result
1746- elif is_bool :
1747- out_arr [res_indexer ] = arr [res_indexer ] ^ arr [lag_indexer ]
1748- else :
1749- out_arr [res_indexer ] = arr [res_indexer ] - arr [lag_indexer ]
1729+ out_arr [res_indexer ] = op (arr [res_indexer ], arr [lag_indexer ])
17501730
17511731 if is_timedelta :
17521732 out_arr = out_arr .view ("timedelta64[ns]" )
0 commit comments