Skip to content

Commit 5108b02

Browse files
committed
Move everything to CFTimedeltaCoder; reuse code where possible
1 parent a21b137 commit 5108b02

File tree

2 files changed

+39
-110
lines changed

2 files changed

+39
-110
lines changed

xarray/coding/times.py

Lines changed: 38 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
unpack_for_decoding,
2222
unpack_for_encoding,
2323
)
24-
from xarray.core import duck_array_ops, indexing
24+
from xarray.core import indexing
2525
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
2626
from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape
2727
from xarray.core.formatting import first_n_items, format_timestamp, last_item
@@ -1400,6 +1400,7 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
14001400
dims, data, attrs, encoding = unpack_for_encoding(variable)
14011401
if "units" in encoding and not has_timedelta64_encoding_dtype(encoding):
14021402
dtype = encoding.pop("dtype", None)
1403+
units = encoding.pop("units", None)
14031404

14041405
# in the case of packed data we need to encode into
14051406
# float first, the correct dtype will be established
@@ -1409,124 +1410,54 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
14091410
set_dtype_encoding = dtype
14101411
dtype = data.dtype if data.dtype.kind == "f" else "float64"
14111412

1412-
data, units = encode_cf_timedelta(
1413-
data, encoding.pop("units", None), dtype
1414-
)
1415-
14161413
# retain dtype for packed data
14171414
if set_dtype_encoding is not None:
14181415
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)
1419-
1420-
safe_setitem(attrs, "units", units, name=name)
1421-
1422-
return Variable(dims, data, attrs, encoding, fastpath=True)
14231416
else:
1424-
return variable
1417+
resolution, _ = np.datetime_data(variable.dtype)
1418+
dtype = np.int64
1419+
attrs_dtype = f"timedelta64[{resolution}]"
1420+
units = _numpy_dtype_to_netcdf_timeunit(variable.dtype)
1421+
safe_setitem(attrs, "dtype", attrs_dtype, name=name)
1422+
# Remove dtype encoding if it exists to prevent it from
1423+
# interfering downstream in NonStringCoder.
1424+
encoding.pop("dtype", None)
1425+
data, units = encode_cf_timedelta(data, units, dtype)
1426+
safe_setitem(attrs, "units", units, name=name)
1427+
return Variable(dims, data, attrs, encoding, fastpath=True)
14251428
else:
14261429
return variable
14271430

14281431
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
14291432
units = variable.attrs.get("units", None)
1430-
if (
1431-
isinstance(units, str)
1432-
and units in TIME_UNITS
1433-
and not has_timedelta64_encoding_dtype(variable.attrs)
1434-
):
1435-
if self._emit_decode_timedelta_future_warning:
1436-
emit_user_level_warning(
1437-
"In a future version of xarray decode_timedelta will "
1438-
"default to False rather than None. To silence this "
1439-
"warning, set decode_timedelta to True, False, or a "
1440-
"'CFTimedeltaCoder' instance.",
1441-
FutureWarning,
1442-
)
1433+
if isinstance(units, str) and units in TIME_UNITS:
14431434
dims, data, attrs, encoding = unpack_for_decoding(variable)
1444-
14451435
units = pop_to(attrs, encoding, "units")
1446-
dtype = np.dtype(f"timedelta64[{self.time_unit}]")
1447-
transform = partial(
1448-
decode_cf_timedelta, units=units, time_unit=self.time_unit
1449-
)
1436+
if has_timedelta64_encoding_dtype(variable.attrs):
1437+
dtype = pop_to(attrs, encoding, "dtype", name=name)
1438+
dtype = np.dtype(dtype)
1439+
resolution, _ = np.datetime_data(dtype)
1440+
if resolution not in typing.get_args(PDDatetimeUnitOptions):
1441+
raise ValueError(
1442+
f"Following pandas, xarray only supports decoding to "
1443+
f"timedelta64 values with a resolution of 's', 'ms', "
1444+
f"'us', or 'ns'. Encoded values have a resolution of "
1445+
f"{resolution!r}."
1446+
)
1447+
time_unit = resolution
1448+
else:
1449+
if self._emit_decode_timedelta_future_warning:
1450+
emit_user_level_warning(
1451+
"In a future version of xarray decode_timedelta will "
1452+
"default to False rather than None. To silence this "
1453+
"warning, set decode_timedelta to True, False, or a "
1454+
"'CFTimedeltaCoder' instance.",
1455+
FutureWarning,
1456+
)
1457+
dtype = np.dtype(f"timedelta64[{self.time_unit}]")
1458+
time_unit = self.time_unit
1459+
transform = partial(decode_cf_timedelta, units=units, time_unit=time_unit)
14501460
data = lazy_elemwise_func(data, transform, dtype=dtype)
1451-
1452-
return Variable(dims, data, attrs, encoding, fastpath=True)
1453-
else:
1454-
return variable
1455-
1456-
1457-
class Timedelta64TypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
1458-
"""Decode arrays on the fly from integer to np.timedelta64 datatype
1459-
1460-
This is useful for decoding timedelta64 arrays from integer typed netCDF
1461-
variables.
1462-
1463-
>>> x = np.array([1, 0, 1, 1, 0], dtype="int64")
1464-
1465-
>>> x.dtype
1466-
dtype('int64')
1467-
1468-
>>> Timedelta64TypeArray(x, np.dtype("timedelta64[ns]")).dtype
1469-
dtype('<m8[ns]')
1470-
1471-
>>> indexer = indexing.BasicIndexer((slice(None),))
1472-
>>> Timedelta64TypeArray(x, np.dtype("timedelta64[ns]"))[indexer].dtype
1473-
dtype('<m8[ns]')
1474-
"""
1475-
1476-
__slots__ = ("_dtype", "array")
1477-
1478-
def __init__(self, array, dtype: np.typing.DTypeLike) -> None:
1479-
self.array = indexing.as_indexable(array)
1480-
self._dtype = dtype
1481-
1482-
@property
1483-
def dtype(self):
1484-
return np.dtype(self._dtype)
1485-
1486-
def _oindex_get(self, key):
1487-
return np.asarray(self.array.oindex[key], dtype=self.dtype)
1488-
1489-
def _vindex_get(self, key):
1490-
return np.asarray(self.array.vindex[key], dtype=self.dtype)
1491-
1492-
def __getitem__(self, key) -> np.ndarray:
1493-
return np.asarray(self.array[key], dtype=self.dtype)
1494-
1495-
1496-
class LiteralTimedelta64Coder(VariableCoder):
1497-
"""Code np.timedelta64 values."""
1498-
1499-
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
1500-
if np.issubdtype(variable.data.dtype, np.timedelta64):
1501-
dims, data, attrs, encoding = unpack_for_encoding(variable)
1502-
resolution, _ = np.datetime_data(variable.dtype)
1503-
dtype = f"timedelta64[{resolution}]"
1504-
units = _numpy_dtype_to_netcdf_timeunit(variable.dtype)
1505-
safe_setitem(attrs, "dtype", dtype, name=name)
1506-
safe_setitem(attrs, "units", units, name=name)
1507-
# Remove dtype encoding if it exists to prevent it from interfering
1508-
# downstream in NonStringCoder.
1509-
encoding.pop("dtype", None)
1510-
data = duck_array_ops.astype(data, dtype=np.int64, copy=True)
1511-
return Variable(dims, data, attrs, encoding, fastpath=True)
1512-
else:
1513-
return variable
1514-
1515-
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
1516-
if has_timedelta64_encoding_dtype(variable.attrs):
1517-
dims, data, attrs, encoding = unpack_for_decoding(variable)
1518-
dtype = pop_to(attrs, encoding, "dtype", name=name)
1519-
pop_to(attrs, encoding, "units", name=name)
1520-
dtype = np.dtype(dtype)
1521-
resolution, _ = np.datetime_data(dtype)
1522-
if resolution not in typing.get_args(PDDatetimeUnitOptions):
1523-
raise ValueError(
1524-
f"Following pandas, xarray only supports decoding to "
1525-
f"timedelta64 values with a resolution of 's', 'ms', "
1526-
f"'us', or 'ns'. Encoded values have a resolution of "
1527-
f"{resolution!r}."
1528-
)
1529-
data = Timedelta64TypeArray(data, dtype)
15301461
return Variable(dims, data, attrs, encoding, fastpath=True)
15311462
else:
15321463
return variable

xarray/conventions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
12-
from xarray.coding import strings, times, variables
12+
from xarray.coding import strings, variables
1313
from xarray.coding.variables import SerializationWarning, pop_to
1414
from xarray.core import indexing
1515
from xarray.core.common import (
@@ -92,7 +92,6 @@ def encode_cf_variable(
9292
for coder in [
9393
CFDatetimeCoder(),
9494
CFTimedeltaCoder(),
95-
times.LiteralTimedelta64Coder(),
9695
variables.CFScaleOffsetCoder(),
9796
variables.CFMaskCoder(),
9897
variables.NativeEnumCoder(),
@@ -243,7 +242,6 @@ def decode_cf_variable(
243242
original_dtype = var.dtype
244243

245244
var = variables.BooleanCoder().decode(var)
246-
var = times.LiteralTimedelta64Coder().decode(var)
247245

248246
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
249247

0 commit comments

Comments
 (0)