diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 30a1c588c61..cac5cb1ea8b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -204,6 +204,15 @@ Bug fixes (:pull:`10352`). By `Spencer Clark `_. - Avoid unsafe casts from float to unsigned int in CFMaskCoder (:issue:`9815`, :pull:`9964`). By ` Elliott Sales de Andrade `_. +- Fix attribute overwriting bug when decoding encoded + :py:class:`numpy.timedelta64` values from disk with a dtype attribute + (:issue:`10468`, :pull:`10469`). By `Spencer Clark + `_. +- Fix default ``"_FillValue"`` dtype coercion bug when encoding + :py:class:`numpy.timedelta64` values to an on-disk format that only supports + 32-bit integers (:issue:`10466`, :pull:`10469`). By `Spencer Clark + `_. + Performance ~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index d1cc36558fa..d6567ba4c61 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -1410,6 +1410,43 @@ def has_timedelta64_encoding_dtype(attrs_or_encoding: dict) -> bool: return isinstance(dtype, str) and dtype.startswith("timedelta64") +def resolve_time_unit_from_attrs_dtype( + attrs_dtype: str, name: T_Name +) -> PDDatetimeUnitOptions: + dtype = np.dtype(attrs_dtype) + resolution, _ = np.datetime_data(dtype) + resolution = cast(NPDatetimeUnitOptions, resolution) + if np.timedelta64(1, resolution) > np.timedelta64(1, "s"): + time_unit = cast(PDDatetimeUnitOptions, "s") + message = ( + f"Following pandas, xarray only supports decoding to timedelta64 " + f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded " + f"values for variable {name!r} have a resolution of " + f"{resolution!r}. Attempting to decode to a resolution of 's'. " + f"Note, depending on the encoded values, this may lead to an " + f"OverflowError. Additionally, data will not be identically round " + f"tripped; xarray will choose an encoding dtype of " + f"'timedelta64[s]' when re-encoding." + ) + emit_user_level_warning(message) + elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"): + time_unit = cast(PDDatetimeUnitOptions, "ns") + message = ( + f"Following pandas, xarray only supports decoding to timedelta64 " + f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded " + f"values for variable {name!r} have a resolution of " + f"{resolution!r}. Attempting to decode to a resolution of 'ns'. " + f"Note, depending on the encoded values, this may lead to loss of " + f"precision. Additionally, data will not be identically round " + f"tripped; xarray will choose an encoding dtype of " + f"'timedelta64[ns]' when re-encoding." + ) + emit_user_level_warning(message) + else: + time_unit = cast(PDDatetimeUnitOptions, resolution) + return time_unit + + class CFTimedeltaCoder(VariableCoder): """Coder for CF Timedelta coding. @@ -1430,7 +1467,7 @@ class CFTimedeltaCoder(VariableCoder): def __init__( self, - time_unit: PDDatetimeUnitOptions = "ns", + time_unit: PDDatetimeUnitOptions | None = None, decode_via_units: bool = True, decode_via_dtype: bool = True, ) -> None: @@ -1442,45 +1479,18 @@ def __init__( def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - has_timedelta_dtype = has_timedelta64_encoding_dtype(encoding) - if ("units" in encoding or "dtype" in encoding) and not has_timedelta_dtype: - dtype = encoding.get("dtype", None) - units = encoding.pop("units", None) + dtype = encoding.get("dtype", None) + units = encoding.pop("units", None) - # in the case of packed data we need to encode into - # float first, the correct dtype will be established - # via CFScaleOffsetCoder/CFMaskCoder - if "add_offset" in encoding or "scale_factor" in encoding: - dtype = data.dtype if data.dtype.kind == "f" else "float64" + # in the case of packed data we need to encode into + # float first, the correct dtype will be established + # via CFScaleOffsetCoder/CFMaskCoder + if "add_offset" in encoding or "scale_factor" in encoding: + dtype = data.dtype if data.dtype.kind == "f" else "float64" - else: - resolution, _ = np.datetime_data(variable.dtype) - dtype = np.int64 - attrs_dtype = f"timedelta64[{resolution}]" - units = _numpy_dtype_to_netcdf_timeunit(variable.dtype) - safe_setitem(attrs, "dtype", attrs_dtype, name=name) - # Remove dtype encoding if it exists to prevent it from - # interfering downstream in NonStringCoder. - encoding.pop("dtype", None) - - if any( - k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS - ): - raise ValueError( - f"Specifying 'add_offset' or 'scale_factor' is not " - f"supported when encoding the timedelta64 values of " - f"variable {name!r} with xarray's new default " - f"timedelta64 encoding approach. To encode {name!r} " - f"with xarray's previous timedelta64 encoding " - f"approach, which supports the 'add_offset' and " - f"'scale_factor' parameters, additionally set " - f"encoding['units'] to a unit of time, e.g. " - f"'seconds'. To proceed with encoding of {name!r} " - f"via xarray's new approach, remove any encoding " - f"entries for 'add_offset' or 'scale_factor'." - ) - if "_FillValue" not in encoding and "missing_value" not in encoding: - encoding["_FillValue"] = np.iinfo(np.int64).min + resolution, _ = np.datetime_data(variable.dtype) + attrs_dtype = f"timedelta64[{resolution}]" + safe_setitem(attrs, "dtype", attrs_dtype, name=name) data, units = encode_cf_timedelta(data, units, dtype) safe_setitem(attrs, "units", units, name=name) @@ -1499,54 +1509,13 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: ): dims, data, attrs, encoding = unpack_for_decoding(variable) units = pop_to(attrs, encoding, "units") - if is_dtype_decodable and self.decode_via_dtype: - if any( - k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS - ): - raise ValueError( - f"Decoding timedelta64 values via dtype is not " - f"supported when 'add_offset', or 'scale_factor' are " - f"present in encoding. Check the encoding parameters " - f"of variable {name!r}." - ) - dtype = pop_to(attrs, encoding, "dtype", name=name) - dtype = np.dtype(dtype) - resolution, _ = np.datetime_data(dtype) - resolution = cast(NPDatetimeUnitOptions, resolution) - if np.timedelta64(1, resolution) > np.timedelta64(1, "s"): - time_unit = cast(PDDatetimeUnitOptions, "s") - dtype = np.dtype("timedelta64[s]") - message = ( - f"Following pandas, xarray only supports decoding to " - f"timedelta64 values with a resolution of 's', 'ms', " - f"'us', or 'ns'. Encoded values for variable {name!r} " - f"have a resolution of {resolution!r}. Attempting to " - f"decode to a resolution of 's'. Note, depending on " - f"the encoded values, this may lead to an " - f"OverflowError. Additionally, data will not be " - f"identically round tripped; xarray will choose an " - f"encoding dtype of 'timedelta64[s]' when re-encoding." - ) - emit_user_level_warning(message) - elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"): - time_unit = cast(PDDatetimeUnitOptions, "ns") - dtype = np.dtype("timedelta64[ns]") - message = ( - f"Following pandas, xarray only supports decoding to " - f"timedelta64 values with a resolution of 's', 'ms', " - f"'us', or 'ns'. Encoded values for variable {name!r} " - f"have a resolution of {resolution!r}. Attempting to " - f"decode to a resolution of 'ns'. Note, depending on " - f"the encoded values, this may lead to loss of " - f"precision. Additionally, data will not be " - f"identically round tripped; xarray will choose an " - f"encoding dtype of 'timedelta64[ns]' " - f"when re-encoding." - ) - emit_user_level_warning(message) + if is_dtype_decodable: + attrs_dtype = attrs.pop("dtype") + if self.time_unit is None: + time_unit = resolve_time_unit_from_attrs_dtype(attrs_dtype, name) else: - time_unit = cast(PDDatetimeUnitOptions, resolution) - elif self.decode_via_units: + time_unit = self.time_unit + else: if self._emit_decode_timedelta_future_warning: emit_user_level_warning( "In a future version, xarray will not decode " @@ -1564,8 +1533,19 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: "'CFTimedeltaCoder' instance.", FutureWarning, ) - dtype = np.dtype(f"timedelta64[{self.time_unit}]") - time_unit = self.time_unit + if self.time_unit is None: + time_unit = cast(PDDatetimeUnitOptions, "ns") + else: + time_unit = self.time_unit + + # Handle edge case that decode_via_dtype=False and + # decode_via_units=True, and timedeltas were encoded with a + # dtype attribute. We need to remove the dtype attribute + # to prevent an error during round tripping. + if has_timedelta_dtype: + attrs.pop("dtype") + + dtype = np.dtype(f"timedelta64[{time_unit}]") transform = partial(decode_cf_timedelta, units=units, time_unit=time_unit) data = lazy_elemwise_func(data, transform, dtype=dtype) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index f42d2c2c17f..2709e834e68 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -56,6 +56,7 @@ from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing from xarray.core.options import set_options +from xarray.core.types import PDDatetimeUnitOptions from xarray.core.utils import module_available from xarray.namedarray.pycompat import array_type from xarray.tests import ( @@ -642,6 +643,16 @@ def test_roundtrip_timedelta_data(self) -> None: ) as actual: assert_identical(expected, actual) + def test_roundtrip_timedelta_data_via_dtype( + self, time_unit: PDDatetimeUnitOptions + ) -> None: + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]).as_unit(time_unit) # type: ignore[arg-type, unused-ignore] + expected = Dataset( + {"td": ("td", time_deltas), "td0": time_deltas[0].to_numpy()} + ) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + def test_roundtrip_float64_data(self) -> None: expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))}) with self.roundtrip(expected) as actual: diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 65caab1c709..af29716fec0 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -20,7 +20,6 @@ ) from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder from xarray.coding.times import ( - _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS, _encode_datetime_with_cftime, _netcdf_to_numpy_timeunit, _numpy_to_netcdf_timeunit, @@ -1824,8 +1823,9 @@ def test_encode_cf_timedelta_small_dtype_missing_value(use_dask) -> None: assert_equal(variable, decoded) -_DECODE_TIMEDELTA_TESTS = { +_DECODE_TIMEDELTA_VIA_UNITS_TESTS = { "default": (True, None, np.dtype("timedelta64[ns]"), True), + "decode_timedelta=True": (True, True, np.dtype("timedelta64[ns]"), False), "decode_timedelta=False": (True, False, np.dtype("int64"), False), "inherit-time_unit-from-decode_times": ( CFDatetimeCoder(time_unit="s"), @@ -1856,16 +1856,16 @@ def test_encode_cf_timedelta_small_dtype_missing_value(use_dask) -> None: @pytest.mark.parametrize( ("decode_times", "decode_timedelta", "expected_dtype", "warns"), - list(_DECODE_TIMEDELTA_TESTS.values()), - ids=list(_DECODE_TIMEDELTA_TESTS.keys()), + list(_DECODE_TIMEDELTA_VIA_UNITS_TESTS.values()), + ids=list(_DECODE_TIMEDELTA_VIA_UNITS_TESTS.keys()), ) -def test_decode_timedelta( +def test_decode_timedelta_via_units( decode_times, decode_timedelta, expected_dtype, warns ) -> None: timedeltas = pd.timedelta_range(0, freq="D", periods=3) - encoding = {"units": "days"} - var = Variable(["time"], timedeltas, encoding=encoding) - encoded = conventions.encode_cf_variable(var) + attrs = {"units": "days"} + var = Variable(["time"], timedeltas, encoding=attrs) + encoded = Variable(["time"], np.array([0, 1, 2]), attrs=attrs) if warns: with pytest.warns(FutureWarning, match="decode_timedelta"): decoded = conventions.decode_cf_variable( @@ -1885,6 +1885,57 @@ def test_decode_timedelta( assert decoded.dtype == expected_dtype +_DECODE_TIMEDELTA_VIA_DTYPE_TESTS = { + "default": (True, None, np.dtype("timedelta64[ns]")), + "decode_timedelta=False": (True, False, np.dtype("int64")), + "decode_timedelta=True": (True, True, np.dtype("timedelta64[ns]")), + "inherit-time_unit-from-decode_times": ( + CFDatetimeCoder(time_unit="s"), + None, + np.dtype("timedelta64[s]"), + ), + "set-time_unit-via-CFTimedeltaCoder-decode_times=True": ( + True, + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), + "set-time_unit-via-CFTimedeltaCoder-decode_times=False": ( + False, + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), + "override-time_unit-from-decode_times": ( + CFDatetimeCoder(time_unit="ns"), + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), +} + + +@pytest.mark.parametrize( + ("decode_times", "decode_timedelta", "expected_dtype"), + list(_DECODE_TIMEDELTA_VIA_DTYPE_TESTS.values()), + ids=list(_DECODE_TIMEDELTA_VIA_DTYPE_TESTS.keys()), +) +def test_decode_timedelta_via_dtype( + decode_times, decode_timedelta, expected_dtype +) -> None: + timedeltas = pd.timedelta_range(0, freq="D", periods=3) + encoding = {"units": "days"} + var = Variable(["time"], timedeltas, encoding=encoding) + encoded = conventions.encode_cf_variable(var) + assert encoded.attrs["dtype"] == "timedelta64[ns]" + assert encoded.attrs["units"] == encoding["units"] + decoded = conventions.decode_cf_variable( + "foo", encoded, decode_times=decode_times, decode_timedelta=decode_timedelta + ) + if decode_timedelta is False: + assert_equal(encoded, decoded) + else: + assert_equal(var, decoded) + assert decoded.dtype == expected_dtype + + def test_lazy_decode_timedelta_unexpected_dtype() -> None: attrs = {"units": "seconds"} encoded = Variable(["time"], [0, 0.5, 1], attrs=attrs) @@ -1940,7 +1991,12 @@ def test_duck_array_decode_times(calendar) -> None: def test_decode_timedelta_mask_and_scale( decode_timedelta: bool, mask_and_scale: bool ) -> None: - attrs = {"units": "nanoseconds", "_FillValue": np.int16(-1), "add_offset": 100000.0} + attrs = { + "dtype": "timedelta64[ns]", + "units": "nanoseconds", + "_FillValue": np.int16(-1), + "add_offset": 100000.0, + } encoded = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=attrs) decoded = conventions.decode_cf_variable( "foo", encoded, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta @@ -1958,19 +2014,17 @@ def test_decode_floating_point_timedelta_no_serialization_warning() -> None: decoded.load() -def test_literal_timedelta64_coding(time_unit: PDDatetimeUnitOptions) -> None: +def test_timedelta64_coding_via_dtype(time_unit: PDDatetimeUnitOptions) -> None: timedeltas = np.array([0, 1, "NaT"], dtype=f"timedelta64[{time_unit}]") variable = Variable(["time"], timedeltas) - expected_dtype = f"timedelta64[{time_unit}]" expected_units = _numpy_to_netcdf_timeunit(time_unit) encoded = conventions.encode_cf_variable(variable) - assert encoded.attrs["dtype"] == expected_dtype + assert encoded.attrs["dtype"] == f"timedelta64[{time_unit}]" assert encoded.attrs["units"] == expected_units - assert encoded.attrs["_FillValue"] == np.iinfo(np.int64).min decoded = conventions.decode_cf_variable("timedeltas", encoded) - assert decoded.encoding["dtype"] == expected_dtype + assert decoded.encoding["dtype"] == np.dtype("int64") assert decoded.encoding["units"] == expected_units assert_identical(decoded, variable) @@ -1981,7 +2035,7 @@ def test_literal_timedelta64_coding(time_unit: PDDatetimeUnitOptions) -> None: assert reencoded.dtype == encoded.dtype -def test_literal_timedelta_coding_non_pandas_coarse_resolution_warning() -> None: +def test_timedelta_coding_via_dtype_non_pandas_coarse_resolution_warning() -> None: attrs = {"dtype": "timedelta64[D]", "units": "days"} encoded = Variable(["time"], [0, 1, 2], attrs=attrs) with pytest.warns(UserWarning, match="xarray only supports"): @@ -1994,7 +2048,7 @@ def test_literal_timedelta_coding_non_pandas_coarse_resolution_warning() -> None @pytest.mark.xfail(reason="xarray does not recognize picoseconds as time-like") -def test_literal_timedelta_coding_non_pandas_fine_resolution_warning() -> None: +def test_timedelta_coding_via_dtype_non_pandas_fine_resolution_warning() -> None: attrs = {"dtype": "timedelta64[ps]", "units": "picoseconds"} encoded = Variable(["time"], [0, 1000, 2000], attrs=attrs) with pytest.warns(UserWarning, match="xarray only supports"): @@ -2006,17 +2060,16 @@ def test_literal_timedelta_coding_non_pandas_fine_resolution_warning() -> None: assert decoded.dtype == np.dtype("timedelta64[ns]") -@pytest.mark.parametrize("attribute", ["dtype", "units"]) -def test_literal_timedelta_decode_invalid_encoding(attribute) -> None: +def test_timedelta_decode_via_dtype_invalid_encoding() -> None: attrs = {"dtype": "timedelta64[s]", "units": "seconds"} - encoding = {attribute: "foo"} + encoding = {"units": "foo"} encoded = Variable(["time"], [0, 1, 2], attrs=attrs, encoding=encoding) with pytest.raises(ValueError, match="failed to prevent"): conventions.decode_cf_variable("timedeltas", encoded) @pytest.mark.parametrize("attribute", ["dtype", "units"]) -def test_literal_timedelta_encode_invalid_attribute(attribute) -> None: +def test_timedelta_encode_via_dtype_invalid_attribute(attribute) -> None: timedeltas = pd.timedelta_range(0, freq="D", periods=3) attrs = {attribute: "foo"} variable = Variable(["time"], timedeltas, attrs=attrs) @@ -2024,23 +2077,6 @@ def test_literal_timedelta_encode_invalid_attribute(attribute) -> None: conventions.encode_cf_variable(variable) -@pytest.mark.parametrize("invalid_key", _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS) -def test_literal_timedelta_encoding_invalid_key_error(invalid_key) -> None: - encoding = {invalid_key: 1.0} - timedeltas = pd.timedelta_range(0, freq="D", periods=3) - variable = Variable(["time"], timedeltas, encoding=encoding) - with pytest.raises(ValueError, match=invalid_key): - conventions.encode_cf_variable(variable) - - -@pytest.mark.parametrize("invalid_key", _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS) -def test_literal_timedelta_decoding_invalid_key_error(invalid_key) -> None: - attrs = {invalid_key: 1.0, "dtype": "timedelta64[s]", "units": "seconds"} - variable = Variable(["time"], [0, 1, 2], attrs=attrs) - with pytest.raises(ValueError, match=invalid_key): - conventions.decode_cf_variable("foo", variable) - - @pytest.mark.parametrize( ("decode_via_units", "decode_via_dtype", "attrs", "expect_timedelta64"), [ @@ -2058,12 +2094,6 @@ def test_literal_timedelta_decoding_invalid_key_error(invalid_key) -> None: def test_timedelta_decoding_options( decode_via_units, decode_via_dtype, attrs, expect_timedelta64 ) -> None: - # Note with literal timedelta encoding, we always add a _FillValue, even - # if one is not present in the original encoding parameters, which is why - # we ensure one is defined here when "dtype" is present in attrs. - if "dtype" in attrs: - attrs["_FillValue"] = np.iinfo(np.int64).min - array = np.array([0, 1, 2], dtype=np.dtype("int64")) encoded = Variable(["time"], array, attrs=attrs) @@ -2083,7 +2113,11 @@ def test_timedelta_decoding_options( # Confirm we exactly roundtrip. reencoded = conventions.encode_cf_variable(decoded) - assert_identical(reencoded, encoded) + + expected = encoded.copy() + if "dtype" not in attrs and decode_via_units: + expected.attrs["dtype"] = "timedelta64[s]" + assert_identical(reencoded, expected) def test_timedelta_encoding_explicit_non_timedelta64_dtype() -> None: @@ -2093,20 +2127,21 @@ def test_timedelta_encoding_explicit_non_timedelta64_dtype() -> None: encoded = conventions.encode_cf_variable(variable) assert encoded.attrs["units"] == "days" + assert encoded.attrs["dtype"] == "timedelta64[ns]" assert encoded.dtype == np.dtype("int32") - with pytest.warns(FutureWarning, match="timedelta"): - decoded = conventions.decode_cf_variable("foo", encoded) + decoded = conventions.decode_cf_variable("foo", encoded) assert_identical(decoded, variable) reencoded = conventions.encode_cf_variable(decoded) assert_identical(reencoded, encoded) assert encoded.attrs["units"] == "days" + assert encoded.attrs["dtype"] == "timedelta64[ns]" assert encoded.dtype == np.dtype("int32") @pytest.mark.parametrize("mask_attribute", ["_FillValue", "missing_value"]) -def test_literal_timedelta64_coding_with_mask( +def test_timedelta64_coding_via_dtype_with_mask( time_unit: PDDatetimeUnitOptions, mask_attribute: str ) -> None: timedeltas = np.array([0, 1, "NaT"], dtype=f"timedelta64[{time_unit}]") @@ -2122,7 +2157,7 @@ def test_literal_timedelta64_coding_with_mask( assert encoded[-1] == mask decoded = conventions.decode_cf_variable("timedeltas", encoded) - assert decoded.encoding["dtype"] == expected_dtype + assert decoded.encoding["dtype"] == np.dtype("int64") assert decoded.encoding["units"] == expected_units assert decoded.encoding[mask_attribute] == mask assert np.isnat(decoded[-1]) @@ -2144,7 +2179,7 @@ def test_roundtrip_0size_timedelta(time_unit: PDDatetimeUnitOptions) -> None: assert encoded.dtype == encoding["dtype"] assert encoded.attrs["units"] == encoding["units"] decoded = conventions.decode_cf_variable("foo", encoded, decode_timedelta=True) - assert decoded.dtype == np.dtype("=m8[ns]") + assert decoded.dtype == np.dtype(f"=m8[{time_unit}]") with assert_no_warnings(): decoded.load() assert decoded.dtype == np.dtype("=m8[s]")