1414 Sequence ,
1515 cast ,
1616 final ,
17+ overload ,
1718)
1819import warnings
1920
101102 Categorical ,
102103 DataFrame ,
103104 Index ,
105+ MultiIndex ,
104106 Series ,
105107 )
106108 from pandas .core .arrays import (
@@ -1780,7 +1782,7 @@ def safe_sort(
17801782 na_sentinel : int = - 1 ,
17811783 assume_unique : bool = False ,
17821784 verify : bool = True ,
1783- ) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
1785+ ) -> np .ndarray | MultiIndex | tuple [np .ndarray | MultiIndex , np .ndarray ]:
17841786 """
17851787 Sort ``values`` and reorder corresponding ``codes``.
17861788
@@ -1809,7 +1811,7 @@ def safe_sort(
18091811
18101812 Returns
18111813 -------
1812- ordered : ndarray
1814+ ordered : ndarray or MultiIndex
18131815 Sorted ``values``
18141816 new_codes : ndarray
18151817 Reordered ``codes``; returned when ``codes`` is not None.
@@ -1827,6 +1829,7 @@ def safe_sort(
18271829 raise TypeError (
18281830 "Only list-like objects are allowed to be passed to safe_sort as values"
18291831 )
1832+ original_values = values
18301833
18311834 if not isinstance (values , (np .ndarray , ABCExtensionArray )):
18321835 # don't convert to string types
@@ -1838,6 +1841,7 @@ def safe_sort(
18381841 values = np .asarray (values , dtype = dtype ) # type: ignore[arg-type]
18391842
18401843 sorter = None
1844+ ordered : np .ndarray | MultiIndex
18411845
18421846 if (
18431847 not is_extension_array_dtype (values )
@@ -1853,7 +1857,7 @@ def safe_sort(
18531857 # which would work, but which fails for special case of 1d arrays
18541858 # with tuples.
18551859 if values .size and isinstance (values [0 ], tuple ):
1856- ordered = _sort_tuples (values )
1860+ ordered = _sort_tuples (values , original_values )
18571861 else :
18581862 ordered = _sort_mixed (values )
18591863
@@ -1915,19 +1919,33 @@ def _sort_mixed(values) -> np.ndarray:
19151919 )
19161920
19171921
1918- def _sort_tuples (values : np .ndarray ) -> np .ndarray :
1922+ @overload
1923+ def _sort_tuples (values : np .ndarray , original_values : np .ndarray ) -> np .ndarray :
1924+ ...
1925+
1926+
1927+ @overload
1928+ def _sort_tuples (values : np .ndarray , original_values : MultiIndex ) -> MultiIndex :
1929+ ...
1930+
1931+
1932+ def _sort_tuples (
1933+ values : np .ndarray , original_values : np .ndarray | MultiIndex
1934+ ) -> np .ndarray | MultiIndex :
19191935 """
19201936 Convert array of tuples (1d) to array of arrays (2d).
19211937 We need to keep the columns separately as they contain different types and
19221938 nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19231939 column as types cannot be compared).
1940+ We have to apply the indexer to the original values to keep the dtypes
1941+ when the original values are a MultiIndex.
19241942 """
19251943 from pandas .core .internals .construction import to_arrays
19261944 from pandas .core .sorting import lexsort_indexer
19271945
19281946 arrays , _ = to_arrays (values , None )
19291947 indexer = lexsort_indexer (arrays , orders = True )
1930- return values [indexer ]
1948+ return original_values [indexer ]
19311949
19321950
19331951def union_with_duplicates (lvals : ArrayLike , rvals : ArrayLike ) -> ArrayLike :
0 commit comments