From 70e5de08a0b0707fc609968b0a9dafc500773665 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Thu, 16 Dec 2021 20:49:39 -0700 Subject: [PATCH] No `fill_value=None`; use fill value out-of-bounds We were previously always returning NaN arrays for chunks that were entirely out of bounds, or had no asset. Now, we respect the user-specified `fill_value` and dtype. This means `fill_value=None` is no longer supported, which makes things more straightforward anyway. --- stackstac/nodata_reader.py | 14 ++++---------- stackstac/reader_protocol.py | 15 ++++++++++----- stackstac/rio_reader.py | 6 ++---- stackstac/stack.py | 12 +++++------- stackstac/to_dask.py | 25 ++++++++++++------------- 5 files changed, 33 insertions(+), 39 deletions(-) diff --git a/stackstac/nodata_reader.py b/stackstac/nodata_reader.py index 0466ec1..e453797 100644 --- a/stackstac/nodata_reader.py +++ b/stackstac/nodata_reader.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, Union, cast +from typing import Tuple, Type, Union, cast import re import numpy as np @@ -6,7 +6,7 @@ from .reader_protocol import Reader -State = Tuple[np.dtype, Optional[Union[int, float]]] +State = Tuple[np.dtype, Union[int, float]] class NodataReader: @@ -17,7 +17,7 @@ def __init__( self, *, dtype: np.dtype, - fill_value: Optional[Union[int, float]] = None, + fill_value: Union[int, float], **kwargs, ) -> None: self.dtype = dtype @@ -36,13 +36,7 @@ def __setstate__(self, state: State) -> None: self.dtype, self.fill_value = state -def nodata_for_window( - window: Window, fill_value: Optional[Union[int, float]], dtype: np.dtype -): - assert ( - fill_value is not None - ), "Trying to convert an exception to nodata, but `fill_value` is None" - +def nodata_for_window(window: Window, fill_value: Union[int, float], dtype: np.dtype): height = cast(int, window.height) width = cast(int, window.width) # Argument of type "tuple[_T@attrib, _T@attrib]" cannot be assigned to parameter "shape" of type "_ShapeLike" diff --git a/stackstac/reader_protocol.py b/stackstac/reader_protocol.py index 294c784..ad4f457 100644 --- a/stackstac/reader_protocol.py +++ b/stackstac/reader_protocol.py @@ -26,6 +26,9 @@ class Reader(Pickleable, Protocol): Protocol for a thread-safe, lazily-loaded object for reading data from a single-band STAC asset. """ + fill_value: Union[int, float] + dtype: np.dtype + def __init__( self, *, @@ -33,7 +36,7 @@ def __init__( spec: RasterSpec, resampling: Resampling, dtype: np.dtype, - fill_value: Optional[Union[int, float]], + fill_value: Union[int, float], rescale: bool, gdal_env: Optional[LayeredEnv], errors_as_nodata: Tuple[Exception, ...] = (), @@ -113,13 +116,15 @@ class FakeReader: or inherent to the dask graph. """ - def __init__(self, *, url: str, spec: RasterSpec, **kwargs) -> None: + def __init__( + self, *, dtype: np.dtype, fill_value: Union[int, float], **kwargs + ) -> None: pass - # self.url = url - # self.spec = spec + self.dtype = dtype + self.fill_value = fill_value def read(self, window: Window, **kwargs) -> np.ndarray: - return np.random.random((window.height, window.width)) + return np.random.random((window.height, window.width)).astype(self.dtype) def close(self) -> None: pass diff --git a/stackstac/rio_reader.py b/stackstac/rio_reader.py index 483436d..73e6bf2 100644 --- a/stackstac/rio_reader.py +++ b/stackstac/rio_reader.py @@ -279,7 +279,7 @@ class PickleState(TypedDict): spec: RasterSpec resampling: Resampling dtype: np.dtype - fill_value: Optional[Union[int, float]] + fill_value: Union[int, float] rescale: bool gdal_env: Optional[LayeredEnv] errors_as_nodata: Tuple[Exception, ...] @@ -302,7 +302,7 @@ def __init__( spec: RasterSpec, resampling: Resampling, dtype: np.dtype, - fill_value: Optional[Union[int, float]], + fill_value: Union[int, float], rescale: bool, gdal_env: Optional[LayeredEnv] = None, errors_as_nodata: Tuple[Exception, ...] = (), @@ -407,8 +407,6 @@ def read(self, window: Window, **kwargs) -> np.ndarray: result = result.astype(self.dtype, copy=False) result = np.ma.filled(result, fill_value=self.fill_value) - # ^ NOTE: if `self.fill_value` was None, rasterio set the masked array's fill value to the - # nodata value of the band, which `np.ma.filled` will then use. return result def close(self) -> None: diff --git a/stackstac/stack.py b/stackstac/stack.py index 8300a49..7494990 100644 --- a/stackstac/stack.py +++ b/stackstac/stack.py @@ -30,7 +30,7 @@ def stack( resampling: Resampling = Resampling.nearest, chunksize: int = 1024, dtype: np.dtype = np.dtype("float64"), - fill_value: Optional[Union[int, float]] = np.nan, + fill_value: Union[int, float] = np.nan, rescale: bool = True, sortby_date: Literal["asc", "desc", False] = "asc", xy_coords: Literal["center", "topleft", False] = "topleft", @@ -192,9 +192,7 @@ def stack( don't set it here---instead, call ``.chunk`` on the DataArray to re-chunk it. dtype: The NumPy data type of the output array. Default: ``float64``. Must be a data type - that's compatible with ``fill_value``. Note that if ``fill_value`` is None, whatever nodata - value is set in each asset's file will be used, so that value needs to be compatible - with ``dtype`` as well. + that's compatible with ``fill_value``. fill_value: Value to fill nodata/masked pixels with. Default: ``np.nan``. @@ -249,7 +247,7 @@ def stack( errors_as_nodata: Exception patterns to ignore when opening datasets or reading data. Exceptions matching the pattern will be logged as warnings, and just - produce nodata (``fill_value``). A non-None ``fill_value`` is required when using this. + produce nodata (``fill_value``). The exception patterns should be instances of an Exception type to catch, where ``str(exception_pattern)`` is a regex pattern to match against @@ -282,9 +280,9 @@ def stack( if sortby_date is not False: plain_items = sorted( plain_items, - key=lambda item: item["properties"].get("datetime", ""), + key=lambda item: item["properties"].get("datetime", "") or "", reverse=sortby_date == "desc", - ) # type: ignore + ) asset_table, spec, asset_ids, plain_items = prepare_items( plain_items, diff --git a/stackstac/to_dask.py b/stackstac/to_dask.py index 897f105..58a788c 100644 --- a/stackstac/to_dask.py +++ b/stackstac/to_dask.py @@ -24,18 +24,12 @@ def items_to_dask( chunksize: int, resampling: Resampling = Resampling.nearest, dtype: np.dtype = np.dtype("float64"), - fill_value: Optional[Union[int, float]] = np.nan, + fill_value: Union[int, float] = np.nan, rescale: bool = True, reader: Type[Reader] = AutoParallelRioReader, gdal_env: Optional[LayeredEnv] = None, errors_as_nodata: Tuple[Exception, ...] = (), ) -> da.Array: - if fill_value is None and errors_as_nodata: - raise ValueError( - "A non-None `fill_value` is required when using `errors_as_nodata`. " - "If an exception occurs, we need to know what to use as the nodata value, " - "since there may not be an open dataset to infer it from." - ) errors_as_nodata = errors_as_nodata or () # be sure it's not None if fill_value is not None and not np.can_cast(fill_value, dtype): @@ -114,17 +108,18 @@ def asset_entry_to_reader_and_window( spec: RasterSpec, resampling: Resampling, dtype: np.dtype, - fill_value: Optional[Union[int, float]], + fill_value: Union[int, float], rescale: bool, gdal_env: Optional[LayeredEnv], errors_as_nodata: Tuple[Exception, ...], reader: Type[ReaderT], -) -> Optional[Tuple[ReaderT, windows.Window]]: +) -> Tuple[ReaderT, windows.Window] | np.ndarray: asset_entry = asset_entry[0, 0] # ^ because dask adds extra outer dims in `from_array` url = asset_entry["url"] if url is None: - return None + # Signifies empty value + return np.array(fill_value, dtype) asset_bounds: Bbox = asset_entry["bounds"] asset_window = windows.from_bounds(*asset_bounds, transform=spec.transform) @@ -159,11 +154,11 @@ def asset_entry_to_reader_and_window( def fetch_raster_window( - asset_entry: Optional[Tuple[Reader, windows.Window]], + asset_entry: Tuple[ReaderT, windows.Window] | np.ndarray, slices: Tuple[slice, ...], ) -> np.ndarray: current_window = windows.Window.from_slices(*slices) - if asset_entry is not None: + if isinstance(asset_entry, tuple): reader, asset_window = asset_entry # check that the window we're fetching overlaps with the asset @@ -172,7 +167,11 @@ def fetch_raster_window( data = reader.read(current_window) return data[None, None] + fill_arr = np.array(reader.fill_value, reader.dtype) + else: + fill_arr: np.ndarray = asset_entry # no dataset, or we didn't overlap it: return empty data. # use the broadcast trick for even fewer memz - return np.broadcast_to(np.nan, (1, 1) + windows.shape(current_window)) + return np.broadcast_to(fill_arr, (1, 1) + windows.shape(current_window)) +