14 changes: 4 additions & 10 deletions stackstac/nodata_reader.py
@@ -1,12 +1,12 @@
from typing import Optional, Tuple, Type, Union, cast
from typing import Tuple, Type, Union, cast
import re

import numpy as np
from rasterio.windows import Window

from .reader_protocol import Reader

State = Tuple[np.dtype, Optional[Union[int, float]]]
State = Tuple[np.dtype, Union[int, float]]


class NodataReader:
@@ -17,7 +17,7 @@ def __init__(
self,
*,
dtype: np.dtype,
fill_value: Optional[Union[int, float]] = None,
fill_value: Union[int, float],
**kwargs,
) -> None:
self.dtype = dtype
@@ -36,13 +36,7 @@ def __setstate__(self, state: State) -> None:
self.dtype, self.fill_value = state


def nodata_for_window(
window: Window, fill_value: Optional[Union[int, float]], dtype: np.dtype
):
assert (
fill_value is not None
), "Trying to convert an exception to nodata, but `fill_value` is None"

def nodata_for_window(window: Window, fill_value: Union[int, float], dtype: np.dtype):
height = cast(int, window.height)
width = cast(int, window.width)
# Argument of type "tuple[_T@attrib, _T@attrib]" cannot be assigned to parameter "shape" of type "_ShapeLike"
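With `fill_value` now required, `nodata_for_window` drops the `None` assertion. Its body is collapsed above; the following standalone sketch shows the intended behavior, where the `np.full` call is an assumption based on the visible `height`/`width` casts, and the function name is hypothetical:

```python
import numpy as np
from rasterio.windows import Window

def nodata_for_window_sketch(window: Window, fill_value, dtype: np.dtype) -> np.ndarray:
    # Hypothetical stand-in for the collapsed body: allocate a window-shaped
    # array of `fill_value` in the requested dtype (no None case to handle).
    height, width = int(window.height), int(window.width)
    return np.full((height, width), fill_value, dtype)

arr = nodata_for_window_sketch(Window(0, 0, 4, 3), fill_value=0, dtype=np.dtype("uint16"))
assert arr.shape == (3, 4) and arr.dtype == np.uint16
```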
15 changes: 10 additions & 5 deletions stackstac/reader_protocol.py
@@ -26,14 +26,17 @@ class Reader(Pickleable, Protocol):
Protocol for a thread-safe, lazily-loaded object for reading data from a single-band STAC asset.
"""

fill_value: Union[int, float]
dtype: np.dtype

def __init__(
self,
*,
url: str,
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
fill_value: Optional[Union[int, float]],
fill_value: Union[int, float],
rescale: bool,
gdal_env: Optional[LayeredEnv],
errors_as_nodata: Tuple[Exception, ...] = (),
@@ -113,13 +116,15 @@ class FakeReader:
or inherent to the dask graph.
"""

def __init__(self, *, url: str, spec: RasterSpec, **kwargs) -> None:
def __init__(
self, *, dtype: np.dtype, fill_value: Union[int, float], **kwargs
) -> None:
pass
# self.url = url
# self.spec = spec
self.dtype = dtype
self.fill_value = fill_value

def read(self, window: Window, **kwargs) -> np.ndarray:
return np.random.random((window.height, window.width))
return np.random.random((window.height, window.width)).astype(self.dtype)

def close(self) -> None:
pass
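A usage sketch of the updated `FakeReader`, which now takes the same `dtype`/`fill_value` the real readers receive and casts its random output accordingly (the import path follows the file shown above):

```python
import numpy as np
from rasterio.windows import Window
from stackstac.reader_protocol import FakeReader

# FakeReader now exposes the `dtype`/`fill_value` attributes the Reader
# protocol requires, so code like `fetch_raster_window` can rely on them.
reader = FakeReader(dtype=np.dtype("float32"), fill_value=np.nan)
chunk = reader.read(Window(0, 0, 16, 16))
assert chunk.dtype == np.float32   # random data is cast to the requested dtype
assert chunk.shape == (16, 16)     # (window.height, window.width)
```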
6 changes: 2 additions & 4 deletions stackstac/rio_reader.py
@@ -279,7 +279,7 @@ class PickleState(TypedDict):
spec: RasterSpec
resampling: Resampling
dtype: np.dtype
fill_value: Optional[Union[int, float]]
fill_value: Union[int, float]
rescale: bool
gdal_env: Optional[LayeredEnv]
errors_as_nodata: Tuple[Exception, ...]
@@ -302,7 +302,7 @@ def __init__(
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
fill_value: Optional[Union[int, float]],
fill_value: Union[int, float],
rescale: bool,
gdal_env: Optional[LayeredEnv] = None,
errors_as_nodata: Tuple[Exception, ...] = (),
@@ -407,8 +407,6 @@ def read(self, window: Window, **kwargs) -> np.ndarray:

result = result.astype(self.dtype, copy=False)
result = np.ma.filled(result, fill_value=self.fill_value)
# ^ NOTE: if `self.fill_value` was None, rasterio set the masked array's fill value to the
# nodata value of the band, which `np.ma.filled` will then use.
return result

def close(self) -> None:
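For context on the simplified masking step above: with `fill_value` always a number, `np.ma.filled` applies it directly rather than falling back to the band's nodata value (the behavior the removed note described). A minimal illustration:

```python
import numpy as np

masked = np.ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False])
# An explicit fill_value overrides the masked array's own .fill_value,
# so the result no longer depends on what rasterio set from the band's nodata.
print(np.ma.filled(masked, fill_value=np.nan))  # [ 1. nan  3.]
```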
12 changes: 5 additions & 7 deletions stackstac/stack.py
@@ -30,7 +30,7 @@ def stack(
resampling: Resampling = Resampling.nearest,
chunksize: int = 1024,
dtype: np.dtype = np.dtype("float64"),
fill_value: Optional[Union[int, float]] = np.nan,
fill_value: Union[int, float] = np.nan,
rescale: bool = True,
sortby_date: Literal["asc", "desc", False] = "asc",
xy_coords: Literal["center", "topleft", False] = "topleft",
@@ -192,9 +192,7 @@ def stack(
don't set it here---instead, call ``.chunk`` on the DataArray to re-chunk it.
dtype:
The NumPy data type of the output array. Default: ``float64``. Must be a data type
that's compatible with ``fill_value``. Note that if ``fill_value`` is None, whatever nodata
value is set in each asset's file will be used, so that value needs to be compatible
with ``dtype`` as well.
that's compatible with ``fill_value``.
fill_value:
Value to fill nodata/masked pixels with. Default: ``np.nan``.

@@ -249,7 +247,7 @@ def stack(
errors_as_nodata:
Exception patterns to ignore when opening datasets or reading data.
Exceptions matching the pattern will be logged as warnings, and just
produce nodata (``fill_value``). A non-None ``fill_value`` is required when using this.
produce nodata (``fill_value``).

The exception patterns should be instances of an Exception type to catch,
where ``str(exception_pattern)`` is a regex pattern to match against
@@ -282,9 +280,9 @@ def stack(
if sortby_date is not False:
plain_items = sorted(
plain_items,
key=lambda item: item["properties"].get("datetime", ""),
key=lambda item: item["properties"].get("datetime", "") or "",
reverse=sortby_date == "desc",
) # type: ignore
)

asset_table, spec, asset_ids, plain_items = prepare_items(
plain_items,
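A usage sketch of `stack` under the stricter `fill_value` documented above: it must be a number castable to `dtype`, and `errors_as_nodata` no longer needs a non-None check. `items` is a placeholder, not part of this diff, and the 404 pattern is just an example of the regex-matching described in the docstring:

```python
import numpy as np
import stackstac
from rasterio.errors import RasterioIOError

items = [...]  # placeholder: plain STAC item dicts, e.g. from a pystac-client search

data = stackstac.stack(
    items,
    dtype=np.dtype("uint16"),
    fill_value=0,  # must be castable to `dtype`; None is no longer accepted
    errors_as_nodata=(
        # str() of the instance is used as a regex against the raised error's message
        RasterioIOError("HTTP response code: 404"),
    ),
)
```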
25 changes: 12 additions & 13 deletions stackstac/to_dask.py
@@ -24,18 +24,12 @@ def items_to_dask(
chunksize: int,
resampling: Resampling = Resampling.nearest,
dtype: np.dtype = np.dtype("float64"),
fill_value: Optional[Union[int, float]] = np.nan,
fill_value: Union[int, float] = np.nan,
rescale: bool = True,
reader: Type[Reader] = AutoParallelRioReader,
gdal_env: Optional[LayeredEnv] = None,
errors_as_nodata: Tuple[Exception, ...] = (),
) -> da.Array:
if fill_value is None and errors_as_nodata:
raise ValueError(
"A non-None `fill_value` is required when using `errors_as_nodata`. "
"If an exception occurs, we need to know what to use as the nodata value, "
"since there may not be an open dataset to infer it from."
)
errors_as_nodata = errors_as_nodata or () # be sure it's not None

if fill_value is not None and not np.can_cast(fill_value, dtype):
@@ -114,17 +108,18 @@ def asset_entry_to_reader_and_window(
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
fill_value: Optional[Union[int, float]],
fill_value: Union[int, float],
rescale: bool,
gdal_env: Optional[LayeredEnv],
errors_as_nodata: Tuple[Exception, ...],
reader: Type[ReaderT],
) -> Optional[Tuple[ReaderT, windows.Window]]:
) -> Tuple[ReaderT, windows.Window] | np.ndarray:
asset_entry = asset_entry[0, 0]
# ^ because dask adds extra outer dims in `from_array`
url = asset_entry["url"]
if url is None:
return None
# Signifies empty value
return np.array(fill_value, dtype)

asset_bounds: Bbox = asset_entry["bounds"]
asset_window = windows.from_bounds(*asset_bounds, transform=spec.transform)
@@ -159,11 +154,11 @@ def asset_entry_to_reader_and_window(


def fetch_raster_window(
asset_entry: Optional[Tuple[Reader, windows.Window]],
asset_entry: Tuple[ReaderT, windows.Window] | np.ndarray,
slices: Tuple[slice, ...],
) -> np.ndarray:
current_window = windows.Window.from_slices(*slices)
if asset_entry is not None:
if isinstance(asset_entry, tuple):
reader, asset_window = asset_entry

# check that the window we're fetching overlaps with the asset
@@ -172,7 +167,11 @@ def fetch_raster_window(
data = reader.read(current_window)

return data[None, None]
fill_arr = np.array(reader.fill_value, reader.dtype)
else:
fill_arr: np.ndarray = asset_entry

# no dataset, or we didn't overlap it: return empty data.
# use the broadcast trick for even fewer memz
return np.broadcast_to(np.nan, (1, 1) + windows.shape(current_window))
return np.broadcast_to(fill_arr, (1, 1) + windows.shape(current_window))
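The missing-asset path now threads the fill value through the graph as a 0-d array, and the nodata chunk is still produced with the broadcast trick mentioned in the comment above. A small illustration of why this stays cheap (the chunk shape here is arbitrary):

```python
import numpy as np

# 0-d sentinel for a missing asset, like the one returned by
# asset_entry_to_reader_and_window when the asset's URL is None.
fill_arr = np.array(np.nan, np.dtype("float64"))
assert fill_arr.shape == () and fill_arr.nbytes == 8

# The "broadcast trick": a full-size nodata chunk that is just a zero-strided,
# read-only view over that single element, with no (1, 1, H, W) allocation.
chunk = np.broadcast_to(fill_arr, (1, 1, 2048, 2048))
assert chunk.strides == (0, 0, 0, 0)
assert not chunk.flags.writeable
```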