|
7 | 7 | import numpy as np |
8 | 8 |
|
9 | 9 | from pandas.core.dtypes.base import ExtensionDtype |
| 10 | +from pandas.core.dtypes.cast import maybe_cast_to_extension_array |
10 | 11 | from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype |
11 | 12 |
|
12 | 13 | import pandas as pd |
13 | 14 | from pandas.api.extensions import no_default, register_extension_dtype |
14 | 15 | from pandas.core.arraylike import OpsMixin |
15 | | -from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin |
| 16 | +from pandas.core.arrays import ExtensionArray |
16 | 17 | from pandas.core.indexers import check_array_indexer |
17 | 18 |
|
18 | 19 |
|
@@ -45,7 +46,7 @@ def _is_numeric(self) -> bool: |
45 | 46 | return True |
46 | 47 |
|
47 | 48 |
|
48 | | -class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): |
| 49 | +class DecimalArray(OpsMixin, ExtensionArray): |
49 | 50 | __array_priority__ = 1000 |
50 | 51 |
|
51 | 52 | def __init__(self, values, dtype=None, copy=False, context=None): |
@@ -225,13 +226,46 @@ def convert_values(param): |
225 | 226 |
|
226 | 227 | return np.asarray(res, dtype=bool) |
227 | 228 |
|
| 229 | + _do_coerce = True # overriden in DecimalArrayWithoutCoercion |
| 230 | + |
| 231 | + def _arith_method(self, other, op): |
| 232 | + def convert_values(param): |
| 233 | + if isinstance(param, ExtensionArray) or is_list_like(param): |
| 234 | + ovalues = param |
| 235 | + else: # Assume its an object |
| 236 | + ovalues = [param] * len(self) |
| 237 | + return ovalues |
| 238 | + |
| 239 | + lvalues = self |
| 240 | + rvalues = convert_values(other) |
| 241 | + |
| 242 | + # If the operator is not defined for the underlying objects, |
| 243 | + # a TypeError should be raised |
| 244 | + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] |
| 245 | + |
| 246 | + def _maybe_convert(arr): |
| 247 | + if self._do_coerce: |
| 248 | + # https://github.com/pandas-dev/pandas/issues/22850 |
| 249 | + # We catch all regular exceptions here, and fall back |
| 250 | + # to an ndarray. |
| 251 | + res = maybe_cast_to_extension_array(type(self), arr) |
| 252 | + if not isinstance(res, type(self)): |
| 253 | + # exception raised in _from_sequence; ensure we have ndarray |
| 254 | + res = np.asarray(arr) |
| 255 | + else: |
| 256 | + res = np.asarray(arr) |
| 257 | + return res |
| 258 | + |
| 259 | + if op.__name__ in {"divmod", "rdivmod"}: |
| 260 | + a, b = zip(*res) |
| 261 | + return _maybe_convert(a), _maybe_convert(b) |
| 262 | + |
| 263 | + return _maybe_convert(res) |
| 264 | + |
228 | 265 |
|
229 | 266 | def to_decimal(values, context=None): |
230 | 267 | return DecimalArray([decimal.Decimal(x) for x in values], context=context) |
231 | 268 |
|
232 | 269 |
|
233 | 270 | def make_data(): |
234 | 271 | return [decimal.Decimal(random.random()) for _ in range(100)] |
235 | | - |
236 | | - |
237 | | -DecimalArray._add_arithmetic_ops() |
|
0 commit comments