diff --git a/dpath/segments.py b/dpath/segments.py index 7a48817..c3c9846 100644 --- a/dpath/segments.py +++ b/dpath/segments.py @@ -1,6 +1,6 @@ from copy import deepcopy from fnmatch import fnmatchcase -from typing import List, Sequence, Tuple, Iterator, Any, Union, Optional, MutableMapping +from typing import Sequence, Tuple, Iterator, Any, Union, Optional, MutableMapping, MutableSequence from dpath import options from dpath.exceptions import InvalidGlob, InvalidKeyName, PathNotFound @@ -81,7 +81,7 @@ def walk(obj, location=()): yield found -def get(obj, segments): +def get(obj, segments: Path): """ Return the value at the path indicated by segments. @@ -92,6 +92,9 @@ def get(obj, segments): if leaf(current): raise PathNotFound(f"Path: {segments}[{i}]") + if isinstance(current, Sequence) and isinstance(segment, str) and segment.isdecimal(): + segment = int(segment) + current = current[segment] return current @@ -254,7 +257,7 @@ def match(segments: Path, glob: Glob): return False -def extend(thing: List, index: int, value=None): +def extend(thing: MutableSequence, index: int, value=None): """ Extend a sequence like thing such that it contains at least index + 1 many elements. The extension values will be None (default). @@ -280,7 +283,7 @@ def extend(thing: List, index: int, value=None): def _default_creator( - current: Union[MutableMapping, List], + current: Union[MutableMapping, Sequence], segments: Sequence[PathSegment], i: int, hints: Sequence[Tuple[PathSegment, type]] = () @@ -294,7 +297,10 @@ def _default_creator( segment = segments[i] length = len(segments) - if isinstance(segment, int): + if isinstance(current, Sequence): + segment = int(segment) + + if isinstance(current, MutableSequence): extend(current, segment) # Infer the type from the hints provided. @@ -308,7 +314,7 @@ def _default_creator( else: segment_next = None - if isinstance(segment_next, int): + if isinstance(segment_next, int) or (isinstance(segment_next, str) and segment_next.isdecimal()): current[segment] = [] else: current[segment] = {} @@ -336,7 +342,7 @@ def set( for (i, segment) in enumerate(segments[:-1]): # If segment is non-int but supposed to be a sequence index - if isinstance(segment, str) and isinstance(current, Sequence) and segment.isdigit(): + if isinstance(segment, str) and isinstance(current, Sequence) and segment.isdecimal(): segment = int(segment) try: @@ -358,7 +364,7 @@ def set( last_segment = segments[-1] # Resolve ambiguity of last segment - if isinstance(last_segment, str) and isinstance(current, Sequence) and last_segment.isdigit(): + if isinstance(last_segment, str) and isinstance(current, Sequence) and last_segment.isdecimal(): last_segment = int(last_segment) if isinstance(last_segment, int): diff --git a/dpath/types.py b/dpath/types.py index 7bf3d2d..c4a4a56 100644 --- a/dpath/types.py +++ b/dpath/types.py @@ -46,7 +46,7 @@ class MergeType(IntFlag): replaces the destination in this situation.""" -PathSegment = Union[int, str] +PathSegment = Union[int, str, bytes] """Type alias for dict path segments where integers are explicitly casted.""" Filter = Callable[[Any], bool] diff --git a/dpath/version.py b/dpath/version.py index 4260069..5dfae46 100644 --- a/dpath/version.py +++ b/dpath/version.py @@ -1 +1 @@ -VERSION = "2.1.3" +VERSION = "2.1.4" diff --git a/tests/test_new.py b/tests/test_new.py index 15b21c6..ac47e7d 100644 --- a/tests/test_new.py +++ b/tests/test_new.py @@ -52,6 +52,27 @@ def test_set_list_with_dict_int_ambiguity(): assert d == expected +def test_int_segment_list_type_check(): + d = {} + dpath.new(d, "a/b/0/c/0", "hello") + assert 'b' in d.get("a", {}) + assert isinstance(d["a"]["b"], list) + assert len(d["a"]["b"]) == 1 + assert 'c' in d["a"]["b"][0] + assert isinstance(d["a"]["b"][0]["c"], list) + assert len(d["a"]["b"][0]["c"]) == 1 + + +def test_int_segment_dict_type_check(): + d = {"a": {"b": {"0": {}}}} + dpath.new(d, "a/b/0/c/0", "hello") + assert "b" in d.get("a", {}) + assert isinstance(d["a"]["b"], dict) + assert '0' in d["a"]["b"] + assert 'c' in d["a"]["b"]["0"] + assert isinstance(d["a"]["b"]["0"]["c"], list) + + def test_set_new_list_path_with_separator(): # This test kills many birds with one stone, forgive me dict = {