Skip to content

Commit dfe09a8

Browse files
committed
TYP: Annotate openers
Opener proxy methods now match the io.BufferedIOBase prototypes. Remove the version checks for indexed-gzip < 0.8; those old releases supported Python 3.6, while our minimum is now 3.8. A runtime-checkable protocol for .read()/.write() was the easiest way to accommodate unusual file-likes that aren't io.IOBase instances. When indexed-gzip is typed, we may need to adjust the return type of _gzip_open.
1 parent 504776c commit dfe09a8

File tree

2 files changed

+115
-67
lines changed

2 files changed

+115
-67
lines changed

nibabel/openers.py

Lines changed: 114 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,48 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Context manager openers for various fileobject types
1010
"""
11+
from __future__ import annotations
1112

1213
import gzip
13-
import warnings
14+
import io
15+
import typing as ty
1416
from bz2 import BZ2File
1517
from os.path import splitext
1618

17-
from packaging.version import Version
18-
1919
from nibabel.optpkg import optional_package
2020

21-
# is indexed_gzip present and modern?
22-
try:
23-
import indexed_gzip as igzip # type: ignore
21+
if ty.TYPE_CHECKING: # pragma: no cover
22+
from types import TracebackType
2423

25-
version = igzip.__version__
24+
import pyzstd
25+
from _typeshed import WriteableBuffer
2626

27-
HAVE_INDEXED_GZIP = True
27+
ModeRT = ty.Literal['r', 'rt']
28+
ModeRB = ty.Literal['rb']
29+
ModeWT = ty.Literal['w', 'wt']
30+
ModeWB = ty.Literal['wb']
31+
ModeR = ty.Union[ModeRT, ModeRB]
32+
ModeW = ty.Union[ModeWT, ModeWB]
33+
Mode = ty.Union[ModeR, ModeW]
34+
35+
OpenerDef = tuple[ty.Callable[..., io.IOBase], tuple[str, ...]]
36+
else:
37+
pyzstd = optional_package('pyzstd')[0]
2838

29-
# < 0.7 - no good
30-
if Version(version) < Version('0.7.0'):
31-
warnings.warn(f'indexed_gzip is present, but too old (>= 0.7.0 required): {version})')
32-
HAVE_INDEXED_GZIP = False
33-
# >= 0.8 SafeIndexedGzipFile renamed to IndexedGzipFile
34-
elif Version(version) < Version('0.8.0'):
35-
IndexedGzipFile = igzip.SafeIndexedGzipFile
36-
else:
37-
IndexedGzipFile = igzip.IndexedGzipFile
38-
del igzip, version
3939

40+
@ty.runtime_checkable
41+
class Fileish(ty.Protocol):
42+
def read(self, size: int = -1, /) -> bytes:
43+
... # pragma: no cover
44+
45+
def write(self, b: bytes, /) -> int | None:
46+
... # pragma: no cover
47+
48+
49+
try:
50+
from indexed_gzip import IndexedGzipFile # type: ignore
51+
52+
HAVE_INDEXED_GZIP = True
4053
except ImportError:
4154
# nibabel.openers.IndexedGzipFile is imported by nibabel.volumeutils
4255
# to detect compressed file types, so we give a fallback value here.
@@ -51,35 +64,63 @@ class DeterministicGzipFile(gzip.GzipFile):
5164
to a modification time (``mtime``) of 0 seconds.
5265
"""
5366

54-
def __init__(self, filename=None, mode=None, compresslevel=9, fileobj=None, mtime=0):
55-
# These two guards are copied from
67+
def __init__(
68+
self,
69+
filename: str | None = None,
70+
mode: Mode | None = None,
71+
compresslevel: int = 9,
72+
fileobj: io.FileIO | None = None,
73+
mtime: int = 0,
74+
):
75+
if mode is None:
76+
mode = 'rb'
77+
modestr: str = mode
78+
79+
# These two guards are adapted from
5680
# https://github.com/python/cpython/blob/6ab65c6/Lib/gzip.py#L171-L174
57-
if mode and 'b' not in mode:
58-
mode += 'b'
81+
if 'b' not in modestr:
82+
modestr = f'{mode}b'
5983
if fileobj is None:
60-
fileobj = self.myfileobj = open(filename, mode or 'rb')
84+
if filename is None:
85+
raise TypeError('Must define either fileobj or filename')
86+
# Cast because GzipFile.myfileobj has type io.FileIO while open returns ty.IO
87+
fileobj = self.myfileobj = ty.cast(io.FileIO, open(filename, modestr))
6188
return super().__init__(
62-
filename='', mode=mode, compresslevel=compresslevel, fileobj=fileobj, mtime=mtime
89+
filename='',
90+
mode=modestr,
91+
compresslevel=compresslevel,
92+
fileobj=fileobj,
93+
mtime=mtime,
6394
)
6495

6596

66-
def _gzip_open(filename, mode='rb', compresslevel=9, mtime=0, keep_open=False):
97+
def _gzip_open(
98+
filename: str,
99+
mode: Mode = 'rb',
100+
compresslevel: int = 9,
101+
mtime: int = 0,
102+
keep_open: bool = False,
103+
) -> gzip.GzipFile:
104+
105+
if not HAVE_INDEXED_GZIP or mode != 'rb':
106+
gzip_file = DeterministicGzipFile(filename, mode, compresslevel, mtime=mtime)
67107

68108
# use indexed_gzip if possible for faster read access. If keep_open ==
69109
# True, we tell IndexedGzipFile to keep the file handle open. Otherwise
70110
# the IndexedGzipFile will close/open the file on each read.
71-
if HAVE_INDEXED_GZIP and mode == 'rb':
72-
gzip_file = IndexedGzipFile(filename, drop_handles=not keep_open)
73-
74-
# Fall-back to built-in GzipFile
75111
else:
76-
gzip_file = DeterministicGzipFile(filename, mode, compresslevel, mtime=mtime)
112+
gzip_file = IndexedGzipFile(filename, drop_handles=not keep_open)
77113

78114
return gzip_file
79115

80116

81-
def _zstd_open(filename, mode='r', *, level_or_option=None, zstd_dict=None):
82-
pyzstd = optional_package('pyzstd')[0]
117+
def _zstd_open(
118+
filename: str,
119+
mode: Mode = 'r',
120+
*,
121+
level_or_option: int | dict | None = None,
122+
zstd_dict: pyzstd.ZstdDict | None = None,
123+
) -> pyzstd.ZstdFile:
83124
return pyzstd.ZstdFile(filename, mode, level_or_option=level_or_option, zstd_dict=zstd_dict)
84125

85126

@@ -106,7 +147,7 @@ class Opener:
106147
gz_def = (_gzip_open, ('mode', 'compresslevel', 'mtime', 'keep_open'))
107148
bz2_def = (BZ2File, ('mode', 'buffering', 'compresslevel'))
108149
zstd_def = (_zstd_open, ('mode', 'level_or_option', 'zstd_dict'))
109-
compress_ext_map = {
150+
compress_ext_map: dict[str | None, OpenerDef] = {
110151
'.gz': gz_def,
111152
'.bz2': bz2_def,
112153
'.zst': zstd_def,
@@ -123,19 +164,19 @@ class Opener:
123164
'w': default_zst_compresslevel,
124165
}
125166
#: whether to ignore case looking for compression extensions
126-
compress_ext_icase = True
167+
compress_ext_icase: bool = True
127168

128-
def __init__(self, fileish, *args, **kwargs):
129-
if self._is_fileobj(fileish):
169+
fobj: io.IOBase
170+
171+
def __init__(self, fileish: str | io.IOBase, *args, **kwargs):
172+
if isinstance(fileish, (io.IOBase, Fileish)):
130173
self.fobj = fileish
131174
self.me_opened = False
132-
self._name = None
175+
self._name = getattr(fileish, 'name', None)
133176
return
134177
opener, arg_names = self._get_opener_argnames(fileish)
135178
# Get full arguments to check for mode and compresslevel
136-
full_kwargs = kwargs.copy()
137-
n_args = len(args)
138-
full_kwargs.update(dict(zip(arg_names[:n_args], args)))
179+
full_kwargs = {**kwargs, **dict(zip(arg_names, args))}
139180
# Set default mode
140181
if 'mode' not in full_kwargs:
141182
mode = 'rb'
@@ -157,7 +198,7 @@ def __init__(self, fileish, *args, **kwargs):
157198
self._name = fileish
158199
self.me_opened = True
159200

160-
def _get_opener_argnames(self, fileish):
201+
def _get_opener_argnames(self, fileish: str) -> OpenerDef:
161202
_, ext = splitext(fileish)
162203
if self.compress_ext_icase:
163204
ext = ext.lower()
@@ -170,16 +211,12 @@ def _get_opener_argnames(self, fileish):
170211
return self.compress_ext_map[ext]
171212
return self.compress_ext_map[None]
172213

173-
def _is_fileobj(self, obj):
174-
"""Is `obj` a file-like object?"""
175-
return hasattr(obj, 'read') and hasattr(obj, 'write')
176-
177214
@property
178-
def closed(self):
215+
def closed(self) -> bool:
179216
return self.fobj.closed
180217

181218
@property
182-
def name(self):
219+
def name(self) -> str | None:
183220
"""Return ``self.fobj.name`` or self._name if not present
184221
185222
self._name will be None if object was created with a fileobj, otherwise
@@ -188,42 +225,53 @@ def name(self):
188225
return self._name
189226

190227
@property
191-
def mode(self):
192-
return self.fobj.mode
228+
def mode(self) -> str:
229+
# Check and raise our own error for type narrowing purposes
230+
if hasattr(self.fobj, 'mode'):
231+
return self.fobj.mode
232+
raise AttributeError(f'{self.fobj.__class__.__name__} has no attribute "mode"')
193233

194-
def fileno(self):
234+
def fileno(self) -> int:
195235
return self.fobj.fileno()
196236

197-
def read(self, *args, **kwargs):
198-
return self.fobj.read(*args, **kwargs)
237+
def read(self, size: int = -1, /) -> bytes:
238+
return self.fobj.read(size)
199239

200-
def readinto(self, *args, **kwargs):
201-
return self.fobj.readinto(*args, **kwargs)
240+
def readinto(self, buffer: WriteableBuffer, /) -> int | None:
241+
# Check and raise our own error for type narrowing purposes
242+
if hasattr(self.fobj, 'readinto'):
243+
return self.fobj.readinto(buffer)
244+
raise AttributeError(f'{self.fobj.__class__.__name__} has no attribute "readinto"')
202245

203-
def write(self, *args, **kwargs):
204-
return self.fobj.write(*args, **kwargs)
246+
def write(self, b: bytes, /) -> int | None:
247+
return self.fobj.write(b)
205248

206-
def seek(self, *args, **kwargs):
207-
return self.fobj.seek(*args, **kwargs)
249+
def seek(self, pos: int, whence: int = 0, /) -> int:
250+
return self.fobj.seek(pos, whence)
208251

209-
def tell(self, *args, **kwargs):
210-
return self.fobj.tell(*args, **kwargs)
252+
def tell(self, /) -> int:
253+
return self.fobj.tell()
211254

212-
def close(self, *args, **kwargs):
213-
return self.fobj.close(*args, **kwargs)
255+
def close(self, /) -> None:
256+
return self.fobj.close()
214257

215-
def __iter__(self):
258+
def __iter__(self) -> ty.Iterator[bytes]:
216259
return iter(self.fobj)
217260

218-
def close_if_mine(self):
261+
def close_if_mine(self) -> None:
219262
"""Close ``self.fobj`` iff we opened it in the constructor"""
220263
if self.me_opened:
221264
self.close()
222265

223-
def __enter__(self):
266+
def __enter__(self) -> Opener:
224267
return self
225268

226-
def __exit__(self, exc_type, exc_val, exc_tb):
269+
def __exit__(
270+
self,
271+
exc_type: type[BaseException] | None,
272+
exc_val: BaseException | None,
273+
exc_tb: TracebackType | None,
274+
) -> None:
227275
self.close_if_mine()
228276

229277

nibabel/tests/test_openers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, message):
3838
def write(self):
3939
pass
4040

41-
def read(self):
41+
def read(self, size=-1, /):
4242
return self.message
4343

4444

0 commit comments

Comments
 (0)