Description
It is a common feature request (for example #4991) to be able to disable decoding when loading a dataset. To address this, we added a `decoder` keyword argument to the loading mechanism (`torchvision.prototype.datasets.load(..., decoder=...)`). It takes an `Optional[Callable]` with the following signature:
```python
def my_decoder(buffer: BinaryIO) -> torch.Tensor:
    ...
```

If it is a callable, it is passed a buffer from the dataset and the result is integrated into the sample dictionary. If the decoder is `None`, the buffer itself is integrated into the sample dictionary, leaving the decoding to the user.
`vision/torchvision/prototype/datasets/_builtin/imagenet.py`, lines 132 to 134 in 4cacf5a:

```python
return dict(
    path=path,
    image=decoder(buffer) if decoder else buffer,
```
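The pattern in this excerpt can be tried in isolation. The following is a minimal, stdlib-only sketch; `make_sample` and `toy_decoder` are hypothetical names, and the toy decoder returns `bytes` instead of a `torch.Tensor` only so that the example runs without torch or Pillow.

```python
import io
from typing import Any, BinaryIO, Callable, Dict, Optional


def toy_decoder(buffer: BinaryIO) -> bytes:
    # Hypothetical stand-in for a real image decoder: it just reads the raw bytes.
    return buffer.read()


def make_sample(
    path: str,
    buffer: BinaryIO,
    decoder: Optional[Callable[[BinaryIO], Any]],
) -> Dict[str, Any]:
    # Mirrors `image=decoder(buffer) if decoder else buffer`:
    # with a decoder the decoded payload is stored, with None the raw buffer is kept.
    return dict(
        path=path,
        image=decoder(buffer) if decoder else buffer,
    )


decoded = make_sample("images/0001.jpg", io.BytesIO(b"\xff\xd8"), toy_decoder)
raw = make_sample("images/0001.jpg", io.BytesIO(b"\xff\xd8"), None)
```

With `decoder=None`, the user receives the open buffer under `image` and can decode it however they like.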
This works well for images, but already breaks down for videos, as discovered in #4838. The issue is that decoding a video yields more information than a single tensor. The tentative plan in #4838 was to change the signature to
```python
def my_decoder(buffer: BinaryIO) -> Dict[str, Any]:
    ...
```

With this, a decoder can return arbitrary information, which is then integrated into the top level of the sample dictionary.
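A rough sketch of how a dict-returning decoder would be merged into the sample. `fake_video_decoder` and the keys it returns (`video`, `audio`, `video_fps`) are hypothetical placeholders; a real video decoder would return frame and audio tensors plus metadata.

```python
import io
from typing import Any, BinaryIO, Dict


def fake_video_decoder(buffer: BinaryIO) -> Dict[str, Any]:
    # Hypothetical decoder: a real one would return tensors, not raw bytes.
    data = buffer.read()
    return {"video": data, "audio": b"", "video_fps": 30.0}


sample: Dict[str, Any] = {"path": "clips/0001.mp4", "label": 3}
# The decoder output is merged into the top level of the sample dictionary.
sample.update(fake_video_decoder(io.BytesIO(b"frames")))
```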
Unfortunately, looking ahead, I don't think even this architecture will be sufficient. Two issues came to mind:
- The current signature assumes that there is only one type of payload to decode in a dataset, i.e. images or videos. Other types, for example annotation files stored as `.mat`, `.xml`, or `.flo`, will always be decoded. Thus, the user can't completely deactivate the decoding after all. Furthermore, they also can't use any custom decoding for these types if need be.
- The current signature assumes that all payloads of a single type can be decoded by the same decoder. A counterexample is the HD1K optical flow dataset, which uses 16-bit `.png` images as annotations; these have sub-par support in `Pillow`.
To overcome this, I propose a new architecture similar to the `RoutedDecoder` datapipe: a `Decoder` class that holds a sequence of handlers (name up for discussion):
```python
from typing import Any, BinaryIO, Callable, Dict, Optional


class Decoder:
    def __init__(
        self,
        # Each handler returns the decoded fields, or None if it is not responsible.
        *handlers: Callable[[str, BinaryIO], Optional[Dict[str, Any]]],
        must_decode: bool = True,
    ):
        self.handlers = handlers
        self.must_decode = must_decode

    def __call__(
        self,
        path: str,
        buffer: BinaryIO,
        *,
        prefix: str = "",
        include_path: bool = True,
    ) -> Dict[str, Any]:
        # Try the handlers in order; the first non-None output wins.
        for handler in self.handlers:
            output = handler(path, buffer)
            if output is not None:
                break
        else:
            # No handler claimed the file.
            if self.must_decode:
                raise RuntimeError(
                    f"No handler was responsible for decoding the file {path}."
                )
            # Pass the raw buffer through, leaving the decoding to the user.
            output = {(f"{prefix}_" if prefix else "") + "buffer": buffer}
        if include_path:
            output[(f"{prefix}_" if prefix else "") + "path"] = path
        return output
```

If called with a path-buffer pair, the decoder iterates through the registered handlers and returns the first valid output. Thus, each handler can determine, based on the path, whether it is responsible for decoding the current buffer. By default, the decoder bails if no handler decoded the input. This can be relaxed with the `must_decode=False` flag (name up for discussion), which is a convenient way to get a non-decoder.
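To illustrate how a handler can decide from the path whether it is responsible, here is a minimal sketch of a handler factory. `path_handler` and `first_match` are hypothetical helpers, and the HD1K-flavored paths, output keys, and decode lambdas are made up; the only contract assumed is the one described above: return a dict when the handler decodes the file, and `None` otherwise so the next handler is tried.

```python
import io
import pathlib
from typing import Any, BinaryIO, Callable, Dict, Optional, Sequence

Handler = Callable[[str, BinaryIO], Optional[Dict[str, Any]]]


def path_handler(pattern: str, key: str, decode: Callable[[BinaryIO], Any]) -> Handler:
    """Build a hypothetical handler that only claims paths matching `pattern`."""

    def handler(path: str, buffer: BinaryIO) -> Optional[Dict[str, Any]]:
        # PurePosixPath.match compares a relative pattern from the right.
        if not pathlib.PurePosixPath(path).match(pattern):
            return None  # not responsible: the next handler gets a chance
        return {key: decode(buffer)}

    return handler


def first_match(
    handlers: Sequence[Handler], path: str, buffer: BinaryIO
) -> Optional[Dict[str, Any]]:
    # Same dispatch rule as Decoder.__call__: the first non-None output wins.
    for handler in handlers:
        output = handler(path, buffer)
        if output is not None:
            return output
    return None


# A more specific handler for 16-bit annotation PNGs is registered before the
# generic PNG handler; both decode functions are placeholders.
handlers = [
    path_handler("annotations/*.png", "flow", lambda buffer: ("16bit", buffer.read())),
    path_handler("*.png", "image", lambda buffer: ("8bit", buffer.read())),
]

flow = first_match(handlers, "scene/annotations/000.png", io.BytesIO(b"x"))
image = first_match(handlers, "scene/images/000.png", io.BytesIO(b"y"))
other = first_match(handlers, "scene/calib.txt", io.BytesIO(b"z"))
```

Because both handlers accept `.png` files, ordering is what routes the annotation PNGs to the specialized decoder, which is exactly the HD1K situation described earlier.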
We would need to change the `datasets.load` function to

```python
import collections.abc
from typing import Any, BinaryIO, Callable, Dict, Optional, Sequence, Union


def load(
    ...,
    decoder: Optional[
        Union[
            Decoder,
            Callable[[str, BinaryIO], Optional[Dict[str, Any]]],
            Sequence[Callable[[str, BinaryIO], Optional[Dict[str, Any]]]],
        ]
    ] = (),
):
    ...
    if decoder is None:
        # True non-decoder: every buffer is passed through as-is.
        decoder = Decoder(must_decode=False)
    elif not isinstance(decoder, Decoder):
        # Wrap a single callable so it can be unpacked like a sequence.
        user_handlers = (
            decoder if isinstance(decoder, collections.abc.Sequence) else (decoder,)
        )
        # User handlers come first, so they take priority over the built-in ones.
        decoder = Decoder(
            *user_handlers,
            *dataset.info.handlers,
            *default_handlers,
        )
    ...
```

By default, the user would get the dataset-specific handlers as well as the default ones. Custom handlers supplied by the user are processed with a higher priority and can thus override the default behavior if need be. If `None` is passed, we get a true non-decoder. Finally, by passing a `Decoder` instance, the user has full control over the behavior.
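The normalization of the `decoder` argument can be tried out on its own. This sketch only resolves the argument into an ordered handler tuple; `resolve_handlers`, `dispatch`, `default_png`, and `custom_png` are hypothetical names, the dataset and default handlers are passed in explicitly instead of coming from `dataset.info.handlers` and `default_handlers`, and the `Decoder` instance case is left out.

```python
import collections.abc
import io
from typing import Any, BinaryIO, Callable, Dict, Optional, Sequence, Tuple, Union

Handler = Callable[[str, BinaryIO], Optional[Dict[str, Any]]]


def resolve_handlers(
    decoder: Optional[Union[Handler, Sequence[Handler]]],
    dataset_handlers: Sequence[Handler],
    default_handlers: Sequence[Handler],
) -> Optional[Tuple[Handler, ...]]:
    # None means "do not decode"; load() would build Decoder(must_decode=False).
    if decoder is None:
        return None
    # Wrap a single callable so that both cases can be unpacked the same way.
    if isinstance(decoder, collections.abc.Sequence):
        user_handlers = tuple(decoder)
    else:
        user_handlers = (decoder,)
    # User handlers come first and therefore take priority.
    return (*user_handlers, *dataset_handlers, *default_handlers)


def dispatch(
    handlers: Sequence[Handler], path: str, buffer: BinaryIO
) -> Optional[Dict[str, Any]]:
    # First non-None output wins, as in Decoder.__call__.
    for handler in handlers:
        output = handler(path, buffer)
        if output is not None:
            return output
    return None


def default_png(path: str, buffer: BinaryIO) -> Optional[Dict[str, Any]]:
    return {"image": "decoded by default"} if path.endswith(".png") else None


def custom_png(path: str, buffer: BinaryIO) -> Optional[Dict[str, Any]]:
    return {"image": "decoded by user"} if path.endswith(".png") else None


builtin_only = resolve_handlers((), (), (default_png,))  # the `= ()` default
with_override = resolve_handlers(custom_png, (), (default_png,))
result = dispatch(with_override, "0001.png", io.BytesIO())
```

Since `custom_png` is placed ahead of `default_png`, it claims the `.png` file first and the default handler never runs for it.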
Within the dataset definition, the call to the decoder would simply look like
```python
path, buffer = data
sample = dict(...)
sample.update(decoder(path, buffer))
```

or, if multiple buffers need to be decoded,

```python
image_data, ann_data = data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
sample = dict()
sample.update(decoder(image_path, image_buffer, prefix="image"))
sample.update(decoder(ann_path, ann_buffer, prefix="ann"))
```
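With a non-decoder (`must_decode=False`), the `prefix` argument is what keeps the two buffers apart in the sample. The sketch below reproduces only the fallback branch of the proposed `Decoder.__call__` as a standalone function, `passthrough` (a hypothetical name); the file names are made up.

```python
import io
from typing import Any, BinaryIO, Dict


def passthrough(
    path: str, buffer: BinaryIO, *, prefix: str = "", include_path: bool = True
) -> Dict[str, Any]:
    # Same key construction as the fallback branch in Decoder.__call__.
    key_prefix = f"{prefix}_" if prefix else ""
    output: Dict[str, Any] = {key_prefix + "buffer": buffer}
    if include_path:
        output[key_prefix + "path"] = path
    return output


image_buffer = io.BytesIO(b"img")
ann_buffer = io.BytesIO(b"ann")
sample: Dict[str, Any] = dict()
sample.update(passthrough("images/0001.png", image_buffer, prefix="image"))
sample.update(passthrough("annotations/0001.xml", ann_buffer, prefix="ann"))
```

Without the prefixes, both calls would write to the same `buffer` and `path` keys, and the annotation would silently replace the image.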