 import asyncio
 import json
 import logging
-from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List
+from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List

 from zarr.v3.abc.metadata import Metadata

 from zarr.v3.array import AsyncArray, Array
@@ -46,11 +46,11 @@ def to_bytes(self) -> Dict[str, bytes]:
             return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
         else:
             return {
-                ZGROUP_JSON: self.zarr_format,
+                ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(),
                 ZATTRS_JSON: json.dumps(self.attributes).encode(),
             }

-    def __init__(self, attributes: Dict[str, Any] = None, zarr_format: Literal[2, 3] = 3):
+    def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
         attributes_parsed = parse_attributes(attributes)
         zarr_format_parsed = parse_zarr_format(zarr_format)

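For reference, a minimal sketch (not part of the diff) of what the zarr v2 branch of `to_bytes` should now produce, assuming `ZGROUP_JSON` and `ZATTRS_JSON` resolve to `.zgroup` and `.zattrs` and that attributes are stored unmodified:

# Hypothetical example; key names and attribute passthrough are assumptions.
meta = GroupMetadata(attributes={"title": "example"}, zarr_format=2)
assert meta.to_bytes() == {
    ".zgroup": b'{"zarr_format": 2}',
    ".zattrs": b'{"title": "example"}',
}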
@@ -104,7 +104,7 @@ async def open(
         zarr_format: Literal[2, 3] = 3,
     ) -> AsyncGroup:
         store_path = make_store_path(store)
-        zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
+        zarr_json_bytes = await (store_path / ZARR_JSON).get()
         assert zarr_json_bytes is not None

         # TODO: consider trying to autodiscover the zarr-format here
@@ -139,7 +139,7 @@ def from_dict(
         store_path: StorePath,
         data: Dict[str, Any],
         runtime_configuration: RuntimeConfiguration,
-    ) -> Group:
+    ) -> AsyncGroup:
         group = cls(
             metadata=GroupMetadata.from_dict(data),
             store_path=store_path,
@@ -168,10 +168,12 @@ async def getitem(
             zarr_json = json.loads(zarr_json_bytes)
             if zarr_json["node_type"] == "group":
                 return type(self).from_dict(store_path, zarr_json, self.runtime_configuration)
-            if zarr_json["node_type"] == "array":
+            elif zarr_json["node_type"] == "array":
                 return AsyncArray.from_dict(
                     store_path, zarr_json, runtime_configuration=self.runtime_configuration
                 )
+            else:
+                raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
         elif self.metadata.zarr_format == 2:
             # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
             # This guarantees that we will always make at least one extra request to the store
@@ -271,7 +273,7 @@ def __repr__(self):
     async def nchildren(self) -> int:
         raise NotImplementedError

-    async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]:
+    async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
         raise NotImplementedError

     async def contains(self, child: str) -> bool:
@@ -381,8 +383,12 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
         new_metadata = replace(self.metadata, attributes=new_attributes)

         # Write new metadata
-        await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
-        return replace(self, metadata=new_metadata)
+        to_save = new_metadata.to_bytes()
+        awaitables = [(self.store_path / key).set(value) for key, value in to_save.items()]
+        await asyncio.gather(*awaitables)
+
+        async_group = replace(self._async_group, metadata=new_metadata)
+        return replace(self, _async_group=async_group)

     @property
     def metadata(self) -> GroupMetadata:
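The rewritten write path builds one `set` awaitable per metadata document and awaits them concurrently. A self-contained sketch of that pattern, using a plain dict as a stand-in for the store (hypothetical helper names, not the real StorePath API):

import asyncio

async def write_metadata(store: dict, to_save: dict) -> None:
    # set_key stands in for (store_path / key).set(value) from the diff above.
    async def set_key(key: str, value: bytes) -> None:
        store[key] = value

    # One awaitable per key, gathered so the writes run concurrently.
    await asyncio.gather(*(set_key(k, v) for k, v in to_save.items()))

store: dict = {}
asyncio.run(write_metadata(store, {".zgroup": b'{"zarr_format": 2}', ".zattrs": b"{}"}))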
@@ -396,34 +402,38 @@ def attrs(self) -> Attributes:
     def info(self):
         return self._async_group.info

+    @property
+    def store_path(self) -> StorePath:
+        return self._async_group.store_path
+
     def update_attributes(self, new_attributes: Dict[str, Any]):
         self._sync(self._async_group.update_attributes(new_attributes))
         return self

     @property
     def nchildren(self) -> int:
-        return self._sync(self._async_group.nchildren)
+        return self._sync(self._async_group.nchildren())

     @property
-    def children(self) -> List[Array, Group]:
-        _children = self._sync_iter(self._async_group.children)
+    def children(self) -> List[Union[Array, Group]]:
+        _children = self._sync_iter(self._async_group.children())
         return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]

     def __contains__(self, child) -> bool:
         return self._sync(self._async_group.contains(child))

-    def group_keys(self) -> Iterator[str]:
-        return self._sync_iter(self._async_group.group_keys)
+    def group_keys(self) -> List[str]:
+        return self._sync_iter(self._async_group.group_keys())

     def groups(self) -> List[Group]:
         # TODO: in v2 this was a generator that return key: Group
-        return [Group(obj) for obj in self._sync_iter(self._async_group.groups)]
+        return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]

     def array_keys(self) -> List[str]:
-        return self._sync_iter(self._async_group.array_keys)
+        return self._sync_iter(self._async_group.array_keys())

     def arrays(self) -> List[Array]:
-        return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
+        return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]

     def tree(self, expand=False, level=None) -> Any:
         return self._sync(self._async_group.tree(expand=expand, level=level))
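The synchronous wrappers above call the async methods and drain their results eagerly, which is why `group_keys`, `array_keys`, and `children` now return `List` rather than `Iterator`. A simplified, self-contained sketch of that sync-over-async pattern (`_sync_iter` here is a stand-in, not the actual mixin implementation):

import asyncio
from typing import AsyncIterator, List, TypeVar

T = TypeVar("T")

async def _collect(aiterator: AsyncIterator[T]) -> List[T]:
    # Drain the async iterator into a plain list.
    return [item async for item in aiterator]

def _sync_iter(aiterator: AsyncIterator[T]) -> List[T]:
    # Stand-in for the mixin helper: run the coroutine to completion on an event loop.
    return asyncio.run(_collect(aiterator))

async def letters() -> AsyncIterator[str]:
    for ch in "abc":
        yield ch

print(_sync_iter(letters()))  # ['a', 'b', 'c']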