From 17ed0224130da55920c44941e755b1994146dc7f Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Mon, 1 Jul 2024 21:49:22 -0700 Subject: [PATCH 1/6] test: check that store, array, and group classes are serializable w/ pickle and can be dependably roundtripped --- src/zarr/abc/store.py | 5 ++++ src/zarr/buffer.py | 5 ++++ src/zarr/store/memory.py | 19 +++++++++++++++ src/zarr/store/remote.py | 10 ++++++++ src/zarr/testing/store.py | 14 +++++++++++ tests/v3/test_array.py | 38 +++++++++++++++++++++++++++++- tests/v3/test_group.py | 24 +++++++++++++++++++ tests/v3/test_store/test_remote.py | 2 +- 8 files changed, 115 insertions(+), 2 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 14566dfed2..a1e798e74a 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -28,6 +28,11 @@ def _check_writable(self) -> None: if not self.writeable: raise ValueError("store mode does not support writing") + @abstractmethod + def __eq__(self, value: object) -> bool: + """Equality comparison.""" + ... + @abstractmethod async def get( self, diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 86f9b53477..1569d54ef8 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -255,6 +255,11 @@ def __add__(self, other: Buffer) -> Self: np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array))) ) + def __eq__(self, other: object) -> bool: + # Note: this was needed to support comparing MemoryStore instances with Buffer values in them + # if/when we stopped putting buffers in memory stores, this can be removed + return isinstance(other, type(self)) and self.to_bytes() == other.to_bytes() + class NDBuffer: """An n-dimensional memory block diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 7b73330b6c..7309d70236 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -29,6 +29,25 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"MemoryStore({str(self)!r})" + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self._store_dict == other._store_dict + and self.mode == other.mode + ) + + def __setstate__(self, state: tuple[MutableMapping[str, Buffer], OpenMode]) -> None: + # warnings.warn( + # f"unpickling {type(self)} may produce unexpected behavior and should only be used for testing and/or debugging" + # ) + self._store_dict, self._mode = state + + def __getstate__(self) -> tuple[MutableMapping[str, Buffer], OpenMode]: + # warnings.warn( + # f"pickling {type(self)} may produce unexpected behavior and should only be used for testing and/or debugging" + # ) + return self._store_dict, self._mode + async def get( self, key: str, diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index e0b69cac50..a67217a319 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -52,6 +52,7 @@ def __init__( """ super().__init__(mode=mode) + self._storage_options = storage_options if isinstance(url, str): self._url = url.rstrip("/") self._fs, _path = fsspec.url_to_fs(url, **storage_options) @@ -81,6 +82,15 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"" + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self.path == other.path + and self.mode == other.mode + and self._url == other._url + # and self._storage_options == other._storage_options # FIXME: this isn't working for some reason + ) + async def get( self, key: str, diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 9c37ce0434..534f348913 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,3 +1,4 @@ +import pickle from typing import Any, Generic, TypeVar import pytest @@ -42,6 +43,19 @@ def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) + def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None: + # check self equality + assert store == store + + # check store equality with same inputs + # asserting this is important for being able to compare (de)serialized stores + store2 = self.store_cls(**store_kwargs) + assert store == store2 + + def test_serizalizable_store(self, store: S) -> None: + foo = pickle.dumps(store) + assert pickle.loads(foo) == store + def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None: assert store.mode == "w", store.mode assert store.writeable diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 203cfbf860..d95cb9861a 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,6 +1,9 @@ +import pickle + +import numpy as np import pytest -from zarr.array import Array +from zarr.array import Array, AsyncArray from zarr.common import ZarrFormat from zarr.group import Group from zarr.store import LocalStore, MemoryStore @@ -34,3 +37,36 @@ def test_array_name_properties_with_group( assert spam.path == "bar/spam" assert spam.name == "/bar/spam" assert spam.basename == "spam" + + +@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +async def test_serizalizable_async_array( + store: LocalStore | MemoryStore, zarr_format: ZarrFormat +) -> None: + expected = await AsyncArray.create( + store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" + ) + # await expected.setitems(list(range(100))) + + p = pickle.dumps(expected) + actual = pickle.loads(p) + + assert actual == expected + # np.testing.assert_array_equal(await actual.getitem(slice(None)), await expected.getitem(slice(None))) + # TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight + + +@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +def test_serizalizable_sync_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + expected = Array.create( + store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" + ) + expected[:] = list(range(100)) + + p = pickle.dumps(expected) + actual = pickle.loads(p) + + assert actual == expected + np.testing.assert_array_equal(actual[:], expected[:]) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index e11af748b3..0f9596a2d9 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pickle from typing import TYPE_CHECKING, Any from zarr.array import AsyncArray @@ -391,3 +392,26 @@ def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: Zar assert bar.path == "foo/bar" assert bar.name == "/foo/bar" assert bar.basename == "bar" + + +@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +async def test_serizalizable_async_group( + store: LocalStore | MemoryStore, zarr_format: ZarrFormat +) -> None: + expected = await AsyncGroup.create( + store=store, attributes={"foo": 999}, zarr_format=zarr_format + ) + p = pickle.dumps(expected) + actual = pickle.loads(p) + assert actual == expected + + +@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +def test_serizalizable_sync_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) + p = pickle.dumps(expected) + actual = pickle.loads(p) + + assert actual == expected diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index 0dc399be42..a308cfe161 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -104,7 +104,7 @@ def store_kwargs(self, request) -> dict[str, str | bool]: anon = False mode = "w" if request.param == "use_upath": - return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)} + return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode} elif request.param == "use_str": return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url} From 275d6faf66d5a69b3566a5d3706de740f3ec5ded Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Tue, 2 Jul 2024 08:41:51 -0700 Subject: [PATCH 2/6] raise if MemoryStore is pickled --- src/zarr/store/memory.py | 10 ++-------- tests/v3/test_array.py | 4 ++-- tests/v3/test_group.py | 4 ++-- tests/v3/test_store/test_memory.py | 12 ++++++++++++ 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 7309d70236..bb8fe673b5 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -37,16 +37,10 @@ def __eq__(self, other: object) -> bool: ) def __setstate__(self, state: tuple[MutableMapping[str, Buffer], OpenMode]) -> None: - # warnings.warn( - # f"unpickling {type(self)} may produce unexpected behavior and should only be used for testing and/or debugging" - # ) - self._store_dict, self._mode = state + raise NotImplementedError(f"{type(self)} cannot be pickled") def __getstate__(self) -> tuple[MutableMapping[str, Buffer], OpenMode]: - # warnings.warn( - # f"pickling {type(self)} may produce unexpected behavior and should only be used for testing and/or debugging" - # ) - return self._store_dict, self._mode + raise NotImplementedError(f"{type(self)} cannot be pickled") async def get( self, diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index d95cb9861a..dca28312b4 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -39,7 +39,7 @@ def test_array_name_properties_with_group( assert spam.basename == "spam" -@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) async def test_serizalizable_async_array( store: LocalStore | MemoryStore, zarr_format: ZarrFormat @@ -57,7 +57,7 @@ async def test_serizalizable_async_array( # TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight -@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) def test_serizalizable_sync_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: expected = Array.create( diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 0f9596a2d9..d3a7c9de51 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -394,7 +394,7 @@ def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: Zar assert bar.basename == "bar" -@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) async def test_serizalizable_async_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat @@ -407,7 +407,7 @@ async def test_serizalizable_async_group( assert actual == expected -@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"]) +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) def test_serizalizable_sync_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 96b8b19e2c..863ecadf44 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pickle + import pytest from zarr.buffer import Buffer @@ -38,3 +40,13 @@ def test_store_supports_partial_writes(self, store: MemoryStore) -> None: def test_list_prefix(self, store: MemoryStore) -> None: assert True + + def test_serizalizable_store(self, store: MemoryStore) -> None: + with pytest.raises(NotImplementedError): + store.__getstate__() + + with pytest.raises(NotImplementedError): + store.__setstate__({}) + + with pytest.raises(NotImplementedError): + pickle.dumps(store) From 938a0bbb19d4a58dd1ae433990171d698bb041a4 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Fri, 9 Aug 2024 13:54:25 -0700 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Davis Bennett --- tests/v3/test_array.py | 2 +- tests/v3/test_group.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index dca28312b4..2391c1523a 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -59,7 +59,7 @@ async def test_serizalizable_async_array( @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -def test_serizalizable_sync_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: +def test_serizalizable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None: expected = Array.create( store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" ) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index d3a7c9de51..b42b2c22bb 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -409,7 +409,7 @@ async def test_serizalizable_async_group( @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -def test_serizalizable_sync_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: +def test_serizalizable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) p = pickle.dumps(expected) actual = pickle.loads(p) From 4fa132e19bc6cce9d1ef18040904ba5e5554fc1a Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Tue, 13 Aug 2024 16:13:58 -0700 Subject: [PATCH 4/6] fix typos --- tests/v3/test_array.py | 4 ++-- tests/v3/test_group.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 2f2dcb975b..8912ae1f06 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -141,7 +141,7 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -async def test_serizalizable_async_array( +async def test_serializable_async_array( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: expected = await AsyncArray.create( @@ -159,7 +159,7 @@ async def test_serizalizable_async_array( @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -def test_serizalizable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None: +def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None: expected = Array.create( store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" ) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index dc3a599ba4..4af8f52aca 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -658,9 +658,7 @@ async def test_asyncgroup_update_attributes( @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -async def test_serizalizable_async_group( - store: LocalStore | MemoryStore, zarr_format: ZarrFormat -) -> None: +async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: expected = await AsyncGroup.create( store=store, attributes={"foo": 999}, zarr_format=zarr_format ) @@ -671,7 +669,7 @@ async def test_serizalizable_async_group( @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) -def test_serizalizable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: +def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) p = pickle.dumps(expected) actual = pickle.loads(p) From 767df056c47c3973533323489ab5dd56c4b24b82 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Mon, 19 Aug 2024 21:37:59 -0700 Subject: [PATCH 5/6] new buffer __eq__ --- src/zarr/core/buffer.py | 7 ++++--- tests/v3/test_array.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/zarr/core/buffer.py b/src/zarr/core/buffer.py index a74fae3a43..a32081d0d7 100644 --- a/src/zarr/core/buffer.py +++ b/src/zarr/core/buffer.py @@ -262,9 +262,10 @@ def __add__(self, other: Buffer) -> Self: ) def __eq__(self, other: object) -> bool: - # Note: this was needed to support comparing MemoryStore instances with Buffer values in them - # if/when we stopped putting buffers in memory stores, this can be removed - return isinstance(other, type(self)) and self.to_bytes() == other.to_bytes() + # Another Buffer class can override this to choose a more efficient path + return isinstance(other, Buffer) and np.array_equal( + self.as_numpy_array(), other.as_numpy_array() + ) class NDBuffer: diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 539f2e527e..bf2ead28c8 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -8,7 +8,7 @@ from zarr.core.common import ZarrFormat from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore -from zarr.store.core import StorePath +from zarr.store.common import StorePath @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) From c80ee2b5cd5084b6c192409776452a7aec5df253 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Sat, 14 Sep 2024 09:03:17 -0700 Subject: [PATCH 6/6] pickle support for zip store --- src/zarr/store/zip.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 15473aa674..ea31ad934a 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -5,7 +5,7 @@ import time import zipfile from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import Store from zarr.core.buffer import Buffer, BufferPrototype @@ -68,7 +68,7 @@ def __init__( self.compression = compression self.allowZip64 = allowZip64 - async def _open(self) -> None: + def _sync_open(self) -> None: if self._is_open: raise ValueError("store is already open") @@ -83,6 +83,17 @@ async def _open(self) -> None: self._is_open = True + async def _open(self) -> None: + self._sync_open() + + def __getstate__(self) -> tuple[Path, ZipStoreAccessModeLiteral, int, bool]: + return self.path, self._zmode, self.compression, self.allowZip64 + + def __setstate__(self, state: Any) -> None: + self.path, self._zmode, self.compression, self.allowZip64 = state + self._is_open = False + self._sync_open() + def close(self) -> None: super().close() with self._lock: