diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index dd5ff7781e463..737c130161246 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -26,11 +26,9 @@ is_dtype_equal, is_extension_array_dtype, is_integer_dtype, - is_iterator, is_list_like, is_object_dtype, is_scalar, - is_sequence, is_timedelta64_dtype, needs_i8_conversion, ) @@ -324,7 +322,7 @@ def __init__( # of numpy values = maybe_infer_to_datetimelike(values, convert_dates=True) if not isinstance(values, np.ndarray): - values = _convert_to_list_like(values) + values = com.convert_to_list_like(values) # By convention, empty lists result in object dtype: sanitize_dtype = np.dtype("O") if len(values) == 0 else None @@ -2647,17 +2645,6 @@ def recode_for_categories(codes: np.ndarray, old_categories, new_categories): return new_codes -def _convert_to_list_like(list_like): - if hasattr(list_like, "dtype"): - return list_like - if isinstance(list_like, list): - return list_like - if is_sequence(list_like) or isinstance(list_like, tuple) or is_iterator(list_like): - return list(list_like) - - return [list_like] - - def factorize_from_iterable(values): """ Factorize an input `values` into `categories` and `codes`. Preserves diff --git a/pandas/core/common.py b/pandas/core/common.py index bb911c0617242..1ccca5193ab46 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -4,17 +4,16 @@ Note: pandas.core.common is *not* part of the public API. """ -import collections -from collections import abc +from collections import abc, defaultdict from datetime import datetime, timedelta from functools import partial import inspect -from typing import Any, Collection, Iterable, Union +from typing import Any, Collection, Iterable, List, Union import numpy as np from pandas._libs import lib, tslibs -from pandas._typing import T +from pandas._typing import AnyArrayLike, Scalar, T from pandas.compat.numpy import _np_version_under1p17 from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike @@ -24,7 +23,12 @@ is_extension_array_dtype, is_integer, ) -from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries +from pandas.core.dtypes.generic import ( + ABCExtensionArray, + ABCIndex, + ABCIndexClass, + ABCSeries, +) from pandas.core.dtypes.inference import _iterable_not_string from pandas.core.dtypes.missing import isna, isnull, notnull # noqa @@ -367,12 +371,12 @@ def standardize_mapping(into): Series.to_dict """ if not inspect.isclass(into): - if isinstance(into, collections.defaultdict): - return partial(collections.defaultdict, into.default_factory) + if isinstance(into, defaultdict): + return partial(defaultdict, into.default_factory) into = type(into) if not issubclass(into, abc.Mapping): raise TypeError(f"unsupported type: {into}") - elif into == collections.defaultdict: + elif into == defaultdict: raise TypeError("to_dict() only accepts initialized defaultdicts") return into @@ -473,3 +477,18 @@ def f(x): f = mapper return f + + +def convert_to_list_like( + values: Union[Scalar, Iterable, AnyArrayLike] +) -> Union[List, AnyArrayLike]: + """ + Convert list-like or scalar input to list-like. List, numpy and pandas array-like + inputs are returned unmodified whereas others are converted to list. + """ + if isinstance(values, (list, np.ndarray, ABCIndex, ABCSeries, ABCExtensionArray)): + return values + elif isinstance(values, abc.Iterable) and not isinstance(values, str): + return list(values) + + return [values]