@@ -1661,6 +1661,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None)
16611661 return result
16621662
16631663
1664+ def _take_preprocess_indexer_and_fill_value (
1665+ arr , indexer , axis , out , fill_value , allow_fill
1666+ ):
1667+ mask_info = None
1668+
1669+ if indexer is None :
1670+ indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
1671+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1672+ else :
1673+ indexer = ensure_int64 (indexer , copy = False )
1674+ if not allow_fill :
1675+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1676+ mask_info = None , False
1677+ else :
1678+ # check for promotion based on types only (do this first because
1679+ # it's faster than computing a mask)
1680+ dtype , fill_value = maybe_promote (arr .dtype , fill_value )
1681+ if dtype != arr .dtype and (out is None or out .dtype != dtype ):
1682+ # check if promotion is actually required based on indexer
1683+ mask = indexer == - 1
1684+ needs_masking = mask .any ()
1685+ mask_info = mask , needs_masking
1686+ if needs_masking :
1687+ if out is not None and out .dtype != dtype :
1688+ raise TypeError ("Incompatible type for fill_value" )
1689+ else :
1690+ # if not, then depromote, set fill_value to dummy
1691+ # (it won't be used but we don't want the cython code
1692+ # to crash when trying to cast it to dtype)
1693+ dtype , fill_value = arr .dtype , arr .dtype .type ()
1694+
1695+ return indexer , dtype , fill_value , mask_info
1696+
1697+
16641698def take_nd (
16651699 arr ,
16661700 indexer ,
@@ -1700,8 +1734,6 @@ def take_nd(
17001734 subarray : array-like
17011735 May be the same type as the input, or cast to an ndarray.
17021736 """
1703- mask_info = None
1704-
17051737 if fill_value is lib .no_default :
17061738 fill_value = na_value_for_dtype (arr .dtype , compat = False )
17071739
@@ -1712,31 +1744,9 @@ def take_nd(
17121744 arr = extract_array (arr )
17131745 arr = np .asarray (arr )
17141746
1715- if indexer is None :
1716- indexer = np .arange (arr .shape [axis ], dtype = np .int64 )
1717- dtype , fill_value = arr .dtype , arr .dtype .type ()
1718- else :
1719- indexer = ensure_int64 (indexer , copy = False )
1720- if not allow_fill :
1721- dtype , fill_value = arr .dtype , arr .dtype .type ()
1722- mask_info = None , False
1723- else :
1724- # check for promotion based on types only (do this first because
1725- # it's faster than computing a mask)
1726- dtype , fill_value = maybe_promote (arr .dtype , fill_value )
1727- if dtype != arr .dtype and (out is None or out .dtype != dtype ):
1728- # check if promotion is actually required based on indexer
1729- mask = indexer == - 1
1730- needs_masking = mask .any ()
1731- mask_info = mask , needs_masking
1732- if needs_masking :
1733- if out is not None and out .dtype != dtype :
1734- raise TypeError ("Incompatible type for fill_value" )
1735- else :
1736- # if not, then depromote, set fill_value to dummy
1737- # (it won't be used but we don't want the cython code
1738- # to crash when trying to cast it to dtype)
1739- dtype , fill_value = arr .dtype , arr .dtype .type ()
1747+ indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
1748+ arr , indexer , axis , out , fill_value , allow_fill
1749+ )
17401750
17411751 flip_order = False
17421752 if arr .ndim == 2 and arr .flags .f_contiguous :
0 commit comments