Skip to content

Commit 4c98c2b

Browse files
github-actions[bot]pquentinmiguelgrinberg
authored
DSL: preserve the skip_empty setting in to_dict() recursive serializations (#3041) (#3045)
* Try reproducing DSL issue 1577 * better attempt to reproduce * preserve skip_empty setting in recursive serializations --------- (cherry picked from commit 4761d56) Co-authored-by: Quentin Pradet <[email protected]> Co-authored-by: Miguel Grinberg <[email protected]>
1 parent 23c0aaa commit 4c98c2b

File tree

5 files changed

+59
-25
lines changed

5 files changed

+59
-25
lines changed

elasticsearch/dsl/field.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,16 @@ def __init__(
119119
def __getitem__(self, subfield: str) -> "Field":
120120
return cast(Field, self._params.get("fields", {})[subfield])
121121

122-
def _serialize(self, data: Any) -> Any:
122+
def _serialize(self, data: Any, skip_empty: bool) -> Any:
123123
return data
124124

125+
def _safe_serialize(self, data: Any, skip_empty: bool) -> Any:
126+
try:
127+
return self._serialize(data, skip_empty)
128+
except TypeError:
129+
# older method signature, without skip_empty
130+
return self._serialize(data) # type: ignore[call-arg]
131+
125132
def _deserialize(self, data: Any) -> Any:
126133
return data
127134

@@ -133,10 +140,16 @@ def empty(self) -> Optional[Any]:
133140
return AttrList([])
134141
return self._empty()
135142

136-
def serialize(self, data: Any) -> Any:
143+
def serialize(self, data: Any, skip_empty: bool = True) -> Any:
137144
if isinstance(data, (list, AttrList, tuple)):
138-
return list(map(self._serialize, cast(Iterable[Any], data)))
139-
return self._serialize(data)
145+
return list(
146+
map(
147+
self._safe_serialize,
148+
cast(Iterable[Any], data),
149+
[skip_empty] * len(data),
150+
)
151+
)
152+
return self._safe_serialize(data, skip_empty)
140153

141154
def deserialize(self, data: Any) -> Any:
142155
if isinstance(data, (list, AttrList, tuple)):
@@ -186,7 +199,7 @@ def _deserialize(self, data: Any) -> Range["_SupportsComparison"]:
186199
data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr]
187200
return Range(data)
188201

189-
def _serialize(self, data: Any) -> Optional[Dict[str, Any]]:
202+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]:
190203
if data is None:
191204
return None
192205
if not isinstance(data, collections.abc.Mapping):
@@ -550,7 +563,7 @@ def _deserialize(self, data: Any) -> "InnerDoc":
550563
return self._wrap(data)
551564

552565
def _serialize(
553-
self, data: Optional[Union[Dict[str, Any], "InnerDoc"]]
566+
self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool
554567
) -> Optional[Dict[str, Any]]:
555568
if data is None:
556569
return None
@@ -559,7 +572,7 @@ def _serialize(
559572
if isinstance(data, collections.abc.Mapping):
560573
return data
561574

562-
return data.to_dict()
575+
return data.to_dict(skip_empty=skip_empty)
563576

564577
def clean(self, data: Any) -> Any:
565578
data = super().clean(data)
@@ -768,7 +781,7 @@ def clean(self, data: str) -> str:
768781
def _deserialize(self, data: Any) -> bytes:
769782
return base64.b64decode(data)
770783

771-
def _serialize(self, data: Any) -> Optional[str]:
784+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]:
772785
if data is None:
773786
return None
774787
return base64.b64encode(data).decode()
@@ -2619,7 +2632,7 @@ def _deserialize(self, data: Any) -> Union["IPv4Address", "IPv6Address"]:
26192632
# the ipaddress library for pypy only accepts unicode.
26202633
return ipaddress.ip_address(unicode(data))
26212634

2622-
def _serialize(self, data: Any) -> Optional[str]:
2635+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]:
26232636
if data is None:
26242637
return None
26252638
return str(data)
@@ -3367,7 +3380,7 @@ def __init__(
33673380
def _deserialize(self, data: Any) -> "Query":
33683381
return Q(data) # type: ignore[no-any-return]
33693382

3370-
def _serialize(self, data: Any) -> Optional[Dict[str, Any]]:
3383+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]:
33713384
if data is None:
33723385
return None
33733386
return data.to_dict() # type: ignore[no-any-return]

elasticsearch/dsl/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
603603
# if this is a mapped field,
604604
f = self.__get_field(k)
605605
if f and f._coerce:
606-
v = f.serialize(v)
606+
v = f.serialize(v, skip_empty=skip_empty)
607607

608608
# if someone assigned AttrList, unwrap it
609609
if isinstance(v, AttrList):

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,15 +630,22 @@ async def test_can_save_to_different_index(
630630
async def test_save_without_skip_empty_will_include_empty_fields(
631631
async_write_client: AsyncElasticsearch,
632632
) -> None:
633-
test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42})
633+
test_repo = Repository(
634+
field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42}
635+
)
634636
assert await test_repo.save(index="test-document", skip_empty=False)
635637

636638
assert_doc_equals(
637639
{
638640
"found": True,
639641
"_index": "test-document",
640642
"_id": "42",
641-
"_source": {"field_1": [], "field_2": None, "field_3": {}},
643+
"_source": {
644+
"field_1": [],
645+
"field_2": None,
646+
"field_3": {},
647+
"owner": {"name": None},
648+
},
642649
},
643650
await async_write_client.get(index="test-document", id=42),
644651
)

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,15 +624,22 @@ def test_can_save_to_different_index(
624624
def test_save_without_skip_empty_will_include_empty_fields(
625625
write_client: Elasticsearch,
626626
) -> None:
627-
test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42})
627+
test_repo = Repository(
628+
field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42}
629+
)
628630
assert test_repo.save(index="test-document", skip_empty=False)
629631

630632
assert_doc_equals(
631633
{
632634
"found": True,
633635
"_index": "test-document",
634636
"_id": "42",
635-
"_source": {"field_1": [], "field_2": None, "field_3": {}},
637+
"_source": {
638+
"field_1": [],
639+
"field_2": None,
640+
"field_3": {},
641+
"owner": {"name": None},
642+
},
636643
},
637644
write_client.get(index="test-document", id=42),
638645
)

utils/templates/field.py.tpl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,16 @@ class Field(DslBase):
119119
def __getitem__(self, subfield: str) -> "Field":
120120
return cast(Field, self._params.get("fields", {})[subfield])
121121

122-
def _serialize(self, data: Any) -> Any:
122+
def _serialize(self, data: Any, skip_empty: bool) -> Any:
123123
return data
124124

125+
def _safe_serialize(self, data: Any, skip_empty: bool) -> Any:
126+
try:
127+
return self._serialize(data, skip_empty)
128+
except TypeError:
129+
# older method signature, without skip_empty
130+
return self._serialize(data) # type: ignore[call-arg]
131+
125132
def _deserialize(self, data: Any) -> Any:
126133
return data
127134

@@ -133,10 +140,10 @@ class Field(DslBase):
133140
return AttrList([])
134141
return self._empty()
135142

136-
def serialize(self, data: Any) -> Any:
143+
def serialize(self, data: Any, skip_empty: bool = True) -> Any:
137144
if isinstance(data, (list, AttrList, tuple)):
138-
return list(map(self._serialize, cast(Iterable[Any], data)))
139-
return self._serialize(data)
145+
return list(map(self._safe_serialize, cast(Iterable[Any], data), [skip_empty] * len(data)))
146+
return self._safe_serialize(data, skip_empty)
140147

141148
def deserialize(self, data: Any) -> Any:
142149
if isinstance(data, (list, AttrList, tuple)):
@@ -186,7 +193,7 @@ class RangeField(Field):
186193
data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr]
187194
return Range(data)
188195

189-
def _serialize(self, data: Any) -> Optional[Dict[str, Any]]:
196+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]:
190197
if data is None:
191198
return None
192199
if not isinstance(data, collections.abc.Mapping):
@@ -318,7 +325,7 @@ class {{ k.name }}({{ k.parent }}):
318325
return self._wrap(data)
319326

320327
def _serialize(
321-
self, data: Optional[Union[Dict[str, Any], "InnerDoc"]]
328+
self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool
322329
) -> Optional[Dict[str, Any]]:
323330
if data is None:
324331
return None
@@ -327,7 +334,7 @@ class {{ k.name }}({{ k.parent }}):
327334
if isinstance(data, collections.abc.Mapping):
328335
return data
329336

330-
return data.to_dict()
337+
return data.to_dict(skip_empty=skip_empty)
331338

332339
def clean(self, data: Any) -> Any:
333340
data = super().clean(data)
@@ -433,7 +440,7 @@ class {{ k.name }}({{ k.parent }}):
433440
# the ipaddress library for pypy only accepts unicode.
434441
return ipaddress.ip_address(unicode(data))
435442

436-
def _serialize(self, data: Any) -> Optional[str]:
443+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]:
437444
if data is None:
438445
return None
439446
return str(data)
@@ -448,7 +455,7 @@ class {{ k.name }}({{ k.parent }}):
448455
def _deserialize(self, data: Any) -> bytes:
449456
return base64.b64decode(data)
450457

451-
def _serialize(self, data: Any) -> Optional[str]:
458+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]:
452459
if data is None:
453460
return None
454461
return base64.b64encode(data).decode()
@@ -458,7 +465,7 @@ class {{ k.name }}({{ k.parent }}):
458465
def _deserialize(self, data: Any) -> "Query":
459466
return Q(data) # type: ignore[no-any-return]
460467

461-
def _serialize(self, data: Any) -> Optional[Dict[str, Any]]:
468+
def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]:
462469
if data is None:
463470
return None
464471
return data.to_dict() # type: ignore[no-any-return]

0 commit comments

Comments
 (0)