44from typing import Any
55
66import numpy as np
7+ from pandas .api .types import is_extension_array_dtype
78
8- from xarray .core import utils
9+ from xarray .core import npcompat , utils
910
1011# Use as a sentinel value to indicate a dtype appropriate NA value.
1112NA = utils .ReprObject ("<NA>" )
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
6061 # N.B. these casting rules should match pandas
6162 dtype_ : np .typing .DTypeLike
6263 fill_value : Any
63- if np . issubdtype (dtype , np . floating ):
64+ if isdtype (dtype , "real floating" ):
6465 dtype_ = dtype
6566 fill_value = np .nan
66- elif np .issubdtype (dtype , np .timedelta64 ):
67+ elif isinstance ( dtype , np . dtype ) and np .issubdtype (dtype , np .timedelta64 ):
6768 # See https://github.com/numpy/numpy/issues/10685
6869 # np.timedelta64 is a subclass of np.integer
6970 # Check np.timedelta64 before np.integer
7071 fill_value = np .timedelta64 ("NaT" )
7172 dtype_ = dtype
72- elif np . issubdtype (dtype , np . integer ):
73+ elif isdtype (dtype , "integral" ):
7374 dtype_ = np .float32 if dtype .itemsize <= 2 else np .float64
7475 fill_value = np .nan
75- elif np . issubdtype (dtype , np . complexfloating ):
76+ elif isdtype (dtype , "complex floating" ):
7677 dtype_ = dtype
7778 fill_value = np .nan + np .nan * 1j
78- elif np .issubdtype (dtype , np .datetime64 ):
79+ elif isinstance ( dtype , np . dtype ) and np .issubdtype (dtype , np .datetime64 ):
7980 dtype_ = dtype
8081 fill_value = np .datetime64 ("NaT" )
8182 else :
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=False):
118119 -------
119120 fill_value : positive infinity value corresponding to this dtype.
120121 """
121- if issubclass (dtype . type , np . floating ):
122+ if isdtype (dtype , "real floating" ):
122123 return np .inf
123124
124- if issubclass (dtype . type , np . integer ):
125+ if isdtype (dtype , "integral" ):
125126 if max_for_int :
126127 return np .iinfo (dtype ).max
127128 else :
128129 return np .inf
129130
130- if issubclass (dtype . type , np . complexfloating ):
131+ if isdtype (dtype , "complex floating" ):
131132 return np .inf + 1j * np .inf
132133
133134 return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=False):
146147 -------
147148 fill_value : positive infinity value corresponding to this dtype.
148149 """
149- if issubclass (dtype . type , np . floating ):
150+ if isdtype (dtype , "real floating" ):
150151 return - np .inf
151152
152- if issubclass (dtype . type , np . integer ):
153+ if isdtype (dtype , "integral" ):
153154 if min_for_int :
154155 return np .iinfo (dtype ).min
155156 else :
156157 return - np .inf
157158
158- if issubclass (dtype . type , np . complexfloating ):
159+ if isdtype (dtype , "complex floating" ):
159160 return - np .inf - 1j * np .inf
160161
161162 return NINF
162163
163164
164- def is_datetime_like (dtype ):
165+ def is_datetime_like (dtype ) -> bool :
165166 """Check if a dtype is a subclass of the numpy datetime types"""
166- return np .issubdtype (dtype , np .datetime64 ) or np .issubdtype (dtype , np .timedelta64 )
167+ return _is_numpy_subdtype (dtype , (np .datetime64 , np .timedelta64 ))
168+
169+
170+ def is_object (dtype ) -> bool :
171+ """Check if a dtype is object"""
172+ return _is_numpy_subdtype (dtype , object )
173+
174+
175+ def is_string (dtype ) -> bool :
176+ """Check if a dtype is a string dtype"""
177+ return _is_numpy_subdtype (dtype , (np .str_ , np .character ))
178+
179+
180+ def _is_numpy_subdtype (dtype , kind ) -> bool :
181+ if not isinstance (dtype , np .dtype ):
182+ return False
183+
184+ kinds = kind if isinstance (kind , tuple ) else (kind ,)
185+ return any (np .issubdtype (dtype , kind ) for kind in kinds )
186+
187+
188+ def isdtype (dtype , kind : str | tuple [str , ...], xp = None ) -> bool :
189+ """Compatibility wrapper for isdtype() from the array API standard.
190+
191+ Unlike xp.isdtype(), kind must be a string.
192+ """
193+ # TODO(shoyer): remove this wrapper when Xarray requires
194+ # numpy>=2 and pandas extensions arrays are implemented in
195+ # Xarray via the array API
196+ if not isinstance (kind , str ) and not (
197+ isinstance (kind , tuple ) and all (isinstance (k , str ) for k in kind )
198+ ):
199+ raise TypeError (f"kind must be a string or a tuple of strings: { repr (kind )} " )
200+
201+ if isinstance (dtype , np .dtype ):
202+ return npcompat .isdtype (dtype , kind )
203+ elif is_extension_array_dtype (dtype ):
204+ # we never want to match pandas extension array dtypes
205+ return False
206+ else :
207+ if xp is None :
208+ xp = np
209+ return xp .isdtype (dtype , kind )
167210
168211
169212def result_type (
@@ -184,12 +227,26 @@ def result_type(
184227 -------
185228 numpy.dtype for the result.
186229 """
187- types = {np .result_type (t ).type for t in arrays_and_dtypes }
230+ from xarray .core .duck_array_ops import get_array_namespace
231+
232+ # TODO(shoyer): consider moving this logic into get_array_namespace()
233+ # or another helper function.
234+ namespaces = {get_array_namespace (t ) for t in arrays_and_dtypes }
235+ non_numpy = namespaces - {np }
236+ if non_numpy :
237+ [xp ] = non_numpy
238+ else :
239+ xp = np
240+
241+ types = {xp .result_type (t ) for t in arrays_and_dtypes }
188242
189- for left , right in PROMOTE_TO_OBJECT :
190- if any (issubclass (t , left ) for t in types ) and any (
191- issubclass (t , right ) for t in types
192- ):
193- return np .dtype (object )
243+ if any (isinstance (t , np .dtype ) for t in types ):
244+ # only check if there's numpy dtypes – the array API does not
245+ # define the types we're checking for
246+ for left , right in PROMOTE_TO_OBJECT :
247+ if any (np .issubdtype (t , left ) for t in types ) and any (
248+ np .issubdtype (t , right ) for t in types
249+ ):
250+ return xp .dtype (object )
194251
195- return np .result_type (* arrays_and_dtypes )
252+ return xp .result_type (* arrays_and_dtypes )
0 commit comments