From 9d155b974d1089445f56143635862c4270dafd7b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 10 Oct 2022 10:18:33 +0200 Subject: [PATCH] make _setup_fill_arg serializable --- torchvision/prototype/transforms/_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index db1ff4b7b6f..a3980fa2154 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,6 +1,6 @@ +import functools import numbers from collections import defaultdict - from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union import PIL.Image @@ -43,13 +43,19 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: raise TypeError("Got inappropriate fill arg") +def _default_fill(fill: FillType) -> FillType: + return fill + + def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: _check_fill_arg(fill) if isinstance(fill, dict): return fill - return defaultdict(lambda: fill) # type: ignore[return-value, arg-type] + # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. + # If it were possible, we could replace this with `defaultdict(lambda: fill)` + return defaultdict(functools.partial(_default_fill, fill)) def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: