Skip to content

Commit 5719019

Browse files
authored
Merge pull request #1129 from effigies/enh/streams
ENH: Add to/from_stream methods and from_url classmethod to SerializableImage
2 parents b38a99b + 82c50ba commit 5719019

File tree

4 files changed

+147
-31
lines changed

4 files changed

+147
-31
lines changed

nibabel/filebasedimages.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import io
1212
from copy import deepcopy
13+
from urllib import request
1314
from .fileholders import FileHolder
1415
from .filename_parser import (types_filenames, TypesFilenamesError,
1516
splitext_addext)
@@ -488,7 +489,7 @@ def path_maybe_image(klass, filename, sniff=None, sniff_max=1024):
488489

489490
class SerializableImage(FileBasedImage):
490491
"""
491-
Abstract image class for (de)serializing images to/from byte strings.
492+
Abstract image class for (de)serializing images to/from byte streams/strings.
492493
493494
The class doesn't define any image properties.
494495
@@ -501,6 +502,7 @@ class SerializableImage(FileBasedImage):
501502
classmethods:
502503
503504
* from_bytes(bytestring) - make instance by deserializing a byte string
505+
* from_url(url) - make instance by fetching and deserializing a URL
504506
505507
Loading from byte strings should provide round-trip equivalence:
506508
@@ -538,7 +540,43 @@ class SerializableImage(FileBasedImage):
538540
"""
539541

540542
@classmethod
541-
def from_bytes(klass, bytestring):
543+
def _filemap_from_iobase(klass, io_obj: io.IOBase):
544+
"""For single-file image types, make a file map with the correct key"""
545+
if len(klass.files_types) > 1:
546+
raise NotImplementedError(
547+
"(de)serialization is undefined for multi-file images"
548+
)
549+
return klass.make_file_map({klass.files_types[0][0]: io_obj})
550+
551+
@classmethod
552+
def from_stream(klass, io_obj: io.IOBase):
553+
"""Load image from readable IO stream
554+
555+
Convert to BytesIO to enable seeking, if input stream is not seekable
556+
557+
Parameters
558+
----------
559+
io_obj : IOBase object
560+
Readable stream
561+
"""
562+
if not io_obj.seekable():
563+
io_obj = io.BytesIO(io_obj.read())
564+
return klass.from_file_map(klass._filemap_from_iobase(io_obj))
565+
566+
def to_stream(self, io_obj: io.IOBase, **kwargs):
567+
"""Save image to writable IO stream
568+
569+
Parameters
570+
----------
571+
io_obj : IOBase object
572+
Writable stream
573+
\*\*kwargs : keyword arguments
574+
Keyword arguments that may be passed to ``img.to_file_map()``
575+
"""
576+
self.to_file_map(self._filemap_from_iobase(io_obj), **kwargs)
577+
578+
@classmethod
579+
def from_bytes(klass, bytestring: bytes):
542580
""" Construct image from a byte string
543581
544582
Class method
@@ -548,13 +586,9 @@ def from_bytes(klass, bytestring):
548586
bstring : bytes
549587
Byte string containing the on-disk representation of an image
550588
"""
551-
if len(klass.files_types) > 1:
552-
raise NotImplementedError("from_bytes is undefined for multi-file images")
553-
bio = io.BytesIO(bytestring)
554-
file_map = klass.make_file_map({'image': bio, 'header': bio})
555-
return klass.from_file_map(file_map)
589+
return klass.from_stream(io.BytesIO(bytestring))
556590

557-
def to_bytes(self, **kwargs):
591+
def to_bytes(self, **kwargs) -> bytes:
558592
r""" Return a ``bytes`` object with the contents of the file that would
559593
be written if the image were saved.
560594
@@ -568,9 +602,22 @@ def to_bytes(self, **kwargs):
568602
bytes
569603
Serialized image
570604
"""
571-
if len(self.__class__.files_types) > 1:
572-
raise NotImplementedError("to_bytes() is undefined for multi-file images")
573605
bio = io.BytesIO()
574-
file_map = self.make_file_map({'image': bio, 'header': bio})
575-
self.to_file_map(file_map, **kwargs)
606+
self.to_stream(bio, **kwargs)
576607
return bio.getvalue()
608+
609+
@classmethod
610+
def from_url(klass, url, timeout=5):
611+
"""Retrieve and load an image from a URL
612+
613+
Class method
614+
615+
Parameters
616+
----------
617+
url : str or urllib.request.Request object
618+
URL of file to retrieve
619+
timeout : float, optional
620+
Time (in seconds) to wait for a response
621+
"""
622+
response = request.urlopen(url, timeout=timeout)
623+
return klass.from_stream(response)

nibabel/tests/test_filebasedimages.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66

77
import numpy as np
8+
import pytest
89

910
from ..filebasedimages import FileBasedHeader, FileBasedImage, SerializableImage
1011

@@ -127,3 +128,24 @@ def __init__(self, seq=None):
127128
hdr4 = H.from_header(None)
128129
assert isinstance(hdr4, H)
129130
assert hdr4.a_list == []
131+
132+
133+
class MultipartNumpyImage(FBNumpyImage):
134+
# We won't actually try to write these out, just need to test an edge case
135+
files_types = (('header', '.hdr'), ('image', '.npy'))
136+
137+
138+
class SerializableMPNumpyImage(MultipartNumpyImage, SerializableImage):
139+
pass
140+
141+
142+
def test_multifile_stream_failure():
143+
shape = (2, 3, 4)
144+
arr = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
145+
img = SerializableMPNumpyImage(arr)
146+
with pytest.raises(NotImplementedError):
147+
img.to_bytes()
148+
img = SerializableNumpyImage(arr)
149+
bstr = img.to_bytes()
150+
with pytest.raises(NotImplementedError):
151+
SerializableMPNumpyImage.from_bytes(bstr)

nibabel/tests/test_image_api.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import warnings
2727
from functools import partial
2828
from itertools import product
29+
import io
2930
import pathlib
3031

3132
import numpy as np
@@ -523,34 +524,41 @@ def validate_affine_deprecated(self, imaker, params):
523524
img.get_affine()
524525

525526

526-
class SerializeMixin(object):
527-
def validate_to_bytes(self, imaker, params):
527+
class SerializeMixin:
528+
def validate_to_from_stream(self, imaker, params):
528529
img = imaker()
529-
serialized = img.to_bytes()
530-
with InTemporaryDirectory():
531-
fname = 'img' + self.standard_extension
532-
img.to_filename(fname)
533-
with open(fname, 'rb') as fobj:
534-
file_contents = fobj.read()
535-
assert serialized == file_contents
530+
klass = getattr(self, 'klass', img.__class__)
531+
stream = io.BytesIO()
532+
img.to_stream(stream)
536533

537-
def validate_from_bytes(self, imaker, params):
534+
rt_img = klass.from_stream(stream)
535+
assert self._header_eq(img.header, rt_img.header)
536+
assert np.array_equal(img.get_fdata(), rt_img.get_fdata())
537+
538+
def validate_file_stream_equivalence(self, imaker, params):
538539
img = imaker()
539540
klass = getattr(self, 'klass', img.__class__)
540541
with InTemporaryDirectory():
541542
fname = 'img' + self.standard_extension
542543
img.to_filename(fname)
543544

544-
all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}]
545-
for img_params in all_images:
546-
img_a = klass.from_filename(img_params['fname'])
547-
with open(img_params['fname'], 'rb') as fobj:
548-
img_b = klass.from_bytes(fobj.read())
545+
with open("stream", "wb") as fobj:
546+
img.to_stream(fobj)
549547

550-
assert self._header_eq(img_a.header, img_b.header)
548+
# Check that writing gets us the same thing
549+
contents1 = pathlib.Path(fname).read_bytes()
550+
contents2 = pathlib.Path("stream").read_bytes()
551+
assert contents1 == contents2
552+
553+
# Check that reading gets us the same thing
554+
img_a = klass.from_filename(fname)
555+
with open(fname, "rb") as fobj:
556+
img_b = klass.from_stream(fobj)
557+
# This needs to happen while the filehandle is open
551558
assert np.array_equal(img_a.get_fdata(), img_b.get_fdata())
552-
del img_a
553-
del img_b
559+
assert self._header_eq(img_a.header, img_b.header)
560+
del img_a
561+
del img_b
554562

555563
def validate_to_from_bytes(self, imaker, params):
556564
img = imaker()
@@ -572,6 +580,45 @@ def validate_to_from_bytes(self, imaker, params):
572580
del img_a
573581
del img_b
574582

583+
@pytest.fixture(autouse=True)
584+
def setup(self, httpserver, tmp_path):
585+
"""Make pytest fixtures available to validate functions"""
586+
self.httpserver = httpserver
587+
self.tmp_path = tmp_path
588+
589+
def validate_from_url(self, imaker, params):
590+
server = self.httpserver
591+
592+
img = imaker()
593+
img_bytes = img.to_bytes()
594+
595+
server.expect_oneshot_request("/img").respond_with_data(img_bytes)
596+
url = server.url_for("/img")
597+
assert url.startswith("http://") # Check we'll trigger an HTTP handler
598+
rt_img = img.__class__.from_url(url)
599+
600+
assert rt_img.to_bytes() == img_bytes
601+
assert self._header_eq(img.header, rt_img.header)
602+
assert np.array_equal(img.get_fdata(), rt_img.get_fdata())
603+
del img
604+
del rt_img
605+
606+
def validate_from_file_url(self, imaker, params):
607+
tmp_path = self.tmp_path
608+
609+
img = imaker()
610+
import uuid
611+
fname = tmp_path / f'img-{uuid.uuid4()}{self.standard_extension}'
612+
img.to_filename(fname)
613+
614+
rt_img = img.__class__.from_url(f"file:///{fname}")
615+
616+
assert self._header_eq(img.header, rt_img.header)
617+
assert np.array_equal(img.get_fdata(), rt_img.get_fdata())
618+
del img
619+
del rt_img
620+
621+
575622
@staticmethod
576623
def _header_eq(header_a, header_b):
577624
""" Header equality check that can be overridden by a subclass of this test
@@ -583,7 +630,6 @@ def _header_eq(header_a, header_b):
583630
return header_a == header_b
584631

585632

586-
587633
class LoadImageAPI(GenericImageAPI,
588634
DataInterfaceMixin,
589635
AffineMixin,

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ test =
6161
pytest !=5.3.4
6262
pytest-cov
6363
pytest-doctestplus
64+
pytest-httpserver
6465
zstd =
6566
pyzstd >= 0.14.3
6667
all =

0 commit comments

Comments
 (0)