
Support chunks in HashChecker #452

@pmeier

Imagine I have a really large file that I want to read in chunks to avoid memory overflow:

import io
from torchdata.datapipes.iter import IterableWrapper, StreamReader

dp = IterableWrapper([("really_large_file.txt", io.BytesIO(b"foo\nbar\nbaz\n"))])
dp = StreamReader(dp, chunk=4)

for data in dp:
    print(data)
('really_large_file.txt', b'foo\n')
('really_large_file.txt', b'bar\n')
('really_large_file.txt', b'baz\n')

Now, it might be useful to also check the hash of the file. Naively, one could simply attach a

dp = HashChecker(
    dp,
    hash_dict={"really_large_file.txt": "268a5059001855fef30b4f95f82044ed"},
    hash_type="md5",
)

to the datapipe. Unfortunately, this leads to a checksum error. This happens because, if the input is a bytes object, it is taken as the sole item for the hash computation:

https://github.com/pytorch/data/blob/13b574c80e8732744fee6ab9cb7e35b5afc34a3c/torchdata/datapipes/iter/util/hashchecker.py#L73-L76

In contrast, if the input is a stream, it is iterated and fully consumed for the computation:

https://github.com/pytorch/data/blob/13b574c80e8732744fee6ab9cb7e35b5afc34a3c/torchdata/datapipes/iter/util/hashchecker.py#L79-L82
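For illustration, here is a minimal standalone comparison of the two code paths, using hashlib and io directly rather than the datapipes:

import hashlib
import io

content = b"foo\nbar\nbaz\n"

# bytes input: each chunk is hashed in isolation, so every chunk gets its own
# digest and none of them can match the digest of the full file
for chunk in (content[i : i + 4] for i in range(0, len(content), 4)):
    print(hashlib.md5(chunk).hexdigest())

# stream input: the stream is iterated and every line feeds the same hash_func,
# so the final digest equals the digest of the full file
hash_func = hashlib.md5()
for line in io.BytesIO(content):
    hash_func.update(line)
assert hash_func.hexdigest() == hashlib.md5(content).hexdigest()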

Thus, placing the HashChecker before the StreamReader gives the desired behavior here:

import io
from torchdata.datapipes.iter import IterableWrapper, HashChecker, StreamReader

dp = IterableWrapper([("really_large_file.txt", io.BytesIO(b"foo\nbar\nbaz\n"))])
dp = HashChecker(
    dp,
    hash_dict={"really_large_file.txt": "268a5059001855fef30b4f95f82044ed"},
    hash_type="md5",
)
dp = StreamReader(dp, chunk=4)

for data in dp:
    print(data)

However, this has several downsides:

  1. If the stream is not seekable, e.g. an HTTP response, there is nothing left for the StreamReader to read after the HashChecker is finished.
  2. We can't control how the stream is iterated. As the code comment implies, __iter__ is chosen since it is a common interface for all streams. However, the chunks it returns are separated by b"\n". Thus, when iterating over arbitrary binary streams, we might read the whole file at once, which defeats the chunked behavior we want (see the sketch below).
  3. We read from the stream twice, since the data read by the HashChecker is not cached anywhere and the StreamReader has to do it all over again.
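To make the second point concrete, here is a minimal sketch of how __iter__ behaves on a binary stream that contains no newlines:

import io

# binary payload without any b"\n" separator, e.g. image or archive data
stream = io.BytesIO(b"\x00\x01\x02\x03" * 1024)

# iterating the stream splits on b"\n"; since there is none, the whole
# 4 KiB payload comes back as a single "chunk"
chunks = list(stream)
print(len(chunks), len(chunks[0]))  # 1 4096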

Since hash_func can be updated incrementally, would it be possible to introduce a cache based on the file name in case we encounter bytes? Something along the lines of:

dp = iter(self.source_datapipe)
for file_name, data in dp:
    hash_func = ...

    if isinstance(data, (str, bytes, bytearray)):
        # hash_func.update() needs bytes, so encode str chunks first
        if isinstance(data, str):
            data = data.encode()
        hash_func.update(data)

        # keep pulling chunks from the same iterator as long as they belong
        # to the current file, feeding them all into the same hash_func
        for file_name_, data_ in dp:
            if file_name_ != file_name:
                break

            if isinstance(data_, str):
                data_ = data_.encode()
            hash_func.update(data_)
    else:
        ...
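With something along these lines in place, the HashChecker could be attached after the StreamReader, i.e. the usage this issue is asking for (hypothetical behavior, not what the current implementation does):

dp = IterableWrapper([("really_large_file.txt", io.BytesIO(b"foo\nbar\nbaz\n"))])
dp = StreamReader(dp, chunk=4)
# hypothetical: consecutive bytes chunks belonging to the same file are fed
# into one hash_func, and the digest is only compared against hash_dict once
# the file name changes or the datapipe is exhausted
dp = HashChecker(
    dp,
    hash_dict={"really_large_file.txt": "268a5059001855fef30b4f95f82044ed"},
    hash_type="md5",
)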
