Skip to content

Commit aacc5f2

Browse files
authored
Add CacheMapper to map from remote URL to local cached basename (#1296)
* Add CacheMapper to map from remote URL to local cached basename
* Raise exception if 'fn' not in cached metadata
* Fix tests on Windows
1 parent 285094f commit aacc5f2

File tree

4 files changed

+157
-29
lines changed

4 files changed

+157
-29
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import hashlib
5+
import os
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from typing import Any
10+
11+
12+
class AbstractCacheMapper(abc.ABC):
    """Base class for mappers that translate a remote URL into the basename
    under which its locally cached copy is stored.
    """

    @abc.abstractmethod
    def __call__(self, path: str) -> str:
        ...

    def __eq__(self, other: Any) -> bool:
        # Two mappers are interchangeable whenever they are instances of the
        # same class. Subclasses that carry configuration attributes must
        # extend this check to compare those attributes as well.
        return isinstance(other, type(self))

    def __hash__(self) -> int:
        # Must stay consistent with __eq__ above: identity is the class only,
        # so hash on the class. Subclasses with attributes must extend this.
        return hash(type(self))
30+
31+
32+
class BasenameCacheMapper(AbstractCacheMapper):
    """Map a remote URL to the final component of its path.

    Distinct remote paths that share a basename therefore collide onto the
    same cached basename.
    """

    def __call__(self, path: str) -> str:
        return os.path.basename(path)
41+
42+
43+
class HashCacheMapper(AbstractCacheMapper):
    """Map a remote URL to the hex SHA-256 digest of the whole path."""

    def __call__(self, path: str) -> str:
        return hashlib.sha256(path.encode()).hexdigest()
48+
49+
50+
def create_cache_mapper(same_names: bool) -> AbstractCacheMapper:
    """Factory method to create cache mapper for backward compatibility with
    ``CachingFileSystem`` constructor using ``same_names`` kwarg.
    """
    # same_names=True preserves the remote basename; False hashes the path.
    return BasenameCacheMapper() if same_names else HashCacheMapper()

fsspec/implementations/cached.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
from __future__ import annotations
22

33
import contextlib
4-
import hashlib
54
import inspect
65
import logging
76
import os
87
import pickle
98
import tempfile
109
import time
1110
from shutil import rmtree
12-
from typing import ClassVar
11+
from typing import Any, ClassVar
1312

1413
from fsspec import AbstractFileSystem, filesystem
1514
from fsspec.callbacks import _DEFAULT_CALLBACK
1615
from fsspec.compression import compr
1716
from fsspec.core import BaseCache, MMapCache
1817
from fsspec.exceptions import BlocksizeMismatchError
18+
from fsspec.implementations.cache_mapper import create_cache_mapper
1919
from fsspec.spec import AbstractBufferedFile
2020
from fsspec.utils import infer_compression
2121

@@ -115,9 +115,7 @@ def __init__(
115115
self.check_files = check_files
116116
self.expiry = expiry_time
117117
self.compression = compression
118-
# TODO: same_names should allow for variable prefix, not only
119-
# to keep the basename
120-
self.same_names = same_names
118+
self._mapper = create_cache_mapper(same_names)
121119
self.target_protocol = (
122120
target_protocol
123121
if isinstance(target_protocol, str)
@@ -255,11 +253,12 @@ def clear_expired_cache(self, expiry_time=None):
255253

256254
for path, detail in self.cached_files[-1].copy().items():
257255
if time.time() - detail["time"] > expiry_time:
258-
if self.same_names:
259-
basename = os.path.basename(detail["original"])
260-
fn = os.path.join(self.storage[-1], basename)
261-
else:
262-
fn = os.path.join(self.storage[-1], detail["fn"])
256+
fn = detail.get("fn", "")
257+
if not fn:
258+
raise RuntimeError(
259+
f"Cache metadata does not contain 'fn' for {path}"
260+
)
261+
fn = os.path.join(self.storage[-1], fn)
263262
if os.path.exists(fn):
264263
os.remove(fn)
265264
self.cached_files[-1].pop(path)
@@ -339,7 +338,7 @@ def _open(
339338
# TODO: action where partial file exists in read-only cache
340339
logger.debug("Opening partially cached copy of %s" % path)
341340
else:
342-
hash = self.hash_name(path, self.same_names)
341+
hash = self._mapper(path)
343342
fn = os.path.join(self.storage[-1], hash)
344343
blocks = set()
345344
detail = {
@@ -385,8 +384,10 @@ def _open(
385384
self.save_cache()
386385
return f
387386

388-
def hash_name(self, path, same_name):
389-
return hash_name(path, same_name=same_name)
387+
def hash_name(self, path: str, *args: Any) -> str:
388+
# Kept for backward compatibility with downstream libraries.
389+
# Ignores extra arguments, previously same_name boolean.
390+
return self._mapper(path)
390391

391392
def close_and_update(self, f, close):
392393
"""Called when a file is closing, so store the set of blocks"""
@@ -488,7 +489,7 @@ def __eq__(self, other):
488489
and self.check_files == other.check_files
489490
and self.expiry == other.expiry
490491
and self.compression == other.compression
491-
and self.same_names == other.same_names
492+
and self._mapper == other._mapper
492493
and self.target_protocol == other.target_protocol
493494
)
494495

@@ -501,7 +502,7 @@ def __hash__(self):
501502
^ hash(self.check_files)
502503
^ hash(self.expiry)
503504
^ hash(self.compression)
504-
^ hash(self.same_names)
505+
^ hash(self._mapper)
505506
^ hash(self.target_protocol)
506507
)
507508

@@ -546,7 +547,7 @@ def open_many(self, open_files):
546547
details = [self._check_file(sp) for sp in paths]
547548
downpath = [p for p, d in zip(paths, details) if not d]
548549
downfn0 = [
549-
os.path.join(self.storage[-1], self.hash_name(p, self.same_names))
550+
os.path.join(self.storage[-1], self._mapper(p))
550551
for p, d in zip(paths, details)
551552
] # keep these path names for opening later
552553
downfn = [fn for fn, d in zip(downfn0, details) if not d]
@@ -558,7 +559,7 @@ def open_many(self, open_files):
558559
newdetail = [
559560
{
560561
"original": path,
561-
"fn": self.hash_name(path, self.same_names),
562+
"fn": self._mapper(path),
562563
"blocks": True,
563564
"time": time.time(),
564565
"uid": self.fs.ukey(path),
@@ -590,7 +591,7 @@ def commit_many(self, open_files):
590591
pass
591592

592593
def _make_local_details(self, path):
593-
hash = self.hash_name(path, self.same_names)
594+
hash = self._mapper(path)
594595
fn = os.path.join(self.storage[-1], hash)
595596
detail = {
596597
"original": path,
@@ -731,7 +732,7 @@ def __init__(self, **kwargs):
731732

732733
def _check_file(self, path):
733734
self._check_cache()
734-
sha = self.hash_name(path, self.same_names)
735+
sha = self._mapper(path)
735736
for storage in self.storage:
736737
fn = os.path.join(storage, sha)
737738
if os.path.exists(fn):
@@ -752,7 +753,7 @@ def _open(self, path, mode="rb", **kwargs):
752753
if fn:
753754
return open(fn, mode)
754755

755-
sha = self.hash_name(path, self.same_names)
756+
sha = self._mapper(path)
756757
fn = os.path.join(self.storage[-1], sha)
757758
logger.debug("Copying %s to local cache" % path)
758759
kwargs["mode"] = mode
@@ -838,14 +839,6 @@ def __getattr__(self, item):
838839
return getattr(self.fh, item)
839840

840841

841-
def hash_name(path, same_name):
842-
if same_name:
843-
hash = os.path.basename(path)
844-
else:
845-
hash = hashlib.sha256(path.encode()).hexdigest()
846-
return hash
847-
848-
849842
@contextlib.contextmanager
850843
def atomic_write(path, mode="wb"):
851844
"""

fsspec/implementations/tests/test_cached.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import fsspec
99
from fsspec.compression import compr
1010
from fsspec.exceptions import BlocksizeMismatchError
11+
from fsspec.implementations.cache_mapper import create_cache_mapper
1112
from fsspec.implementations.cached import CachingFileSystem, LocalTempFile
13+
from fsspec.implementations.local import make_path_posix
1214

1315
from .test_ftp import FTPFileSystem
1416

@@ -32,6 +34,61 @@ def local_filecache():
3234
return data, original_file, cache_location, fs
3335

3436

37+
def test_mapper():
38+
mapper0 = create_cache_mapper(True)
39+
assert mapper0("/somedir/somefile") == "somefile"
40+
assert mapper0("/otherdir/somefile") == "somefile"
41+
42+
mapper1 = create_cache_mapper(False)
43+
assert (
44+
mapper1("/somedir/somefile")
45+
== "67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056"
46+
)
47+
assert (
48+
mapper1("/otherdir/somefile")
49+
== "f043dee01ab9b752c7f2ecaeb1a5e1b2d872018e2d0a1a26c43835ebf34e7d3e"
50+
)
51+
52+
assert mapper0 != mapper1
53+
assert create_cache_mapper(True) == mapper0
54+
assert create_cache_mapper(False) == mapper1
55+
56+
assert hash(mapper0) != hash(mapper1)
57+
assert hash(create_cache_mapper(True)) == hash(mapper0)
58+
assert hash(create_cache_mapper(False)) == hash(mapper1)
59+
60+
61+
@pytest.mark.parametrize("same_names", [False, True])
62+
def test_metadata(tmpdir, same_names):
63+
source = os.path.join(tmpdir, "source")
64+
afile = os.path.join(source, "afile")
65+
os.mkdir(source)
66+
open(afile, "w").write("test")
67+
68+
fs = fsspec.filesystem(
69+
"filecache",
70+
target_protocol="file",
71+
cache_storage=os.path.join(tmpdir, "cache"),
72+
same_names=same_names,
73+
)
74+
75+
with fs.open(afile, "rb") as f:
76+
assert f.read(5) == b"test"
77+
78+
afile_posix = make_path_posix(afile)
79+
detail = fs.cached_files[0][afile_posix]
80+
assert sorted(detail.keys()) == ["blocks", "fn", "original", "time", "uid"]
81+
assert isinstance(detail["blocks"], bool)
82+
assert isinstance(detail["fn"], str)
83+
assert isinstance(detail["time"], float)
84+
assert isinstance(detail["uid"], str)
85+
86+
assert detail["original"] == afile_posix
87+
assert detail["fn"] == fs._mapper(afile_posix)
88+
if same_names:
89+
assert detail["fn"] == "afile"
90+
91+
3592
def test_idempotent():
3693
fs = CachingFileSystem("file")
3794
fs2 = CachingFileSystem("file")
@@ -154,7 +211,7 @@ def test_clear():
154211

155212

156213
def test_clear_expired(tmp_path):
157-
def __ager(cache_fn, fn):
214+
def __ager(cache_fn, fn, del_fn=False):
158215
"""
159216
Modify the cache file to virtually add time lag to selected files.
160217
@@ -164,6 +221,8 @@ def __ager(cache_fn, fn):
164221
cache path
165222
fn: str
166223
file name to be modified
224+
del_fn: bool
225+
whether or not to delete 'fn' from cache details
167226
"""
168227
import pathlib
169228
import time
@@ -174,6 +233,8 @@ def __ager(cache_fn, fn):
174233
fn_posix = pathlib.Path(fn).as_posix()
175234
cached_files[fn_posix]["time"] = cached_files[fn_posix]["time"] - 691200
176235
assert os.access(cache_fn, os.W_OK), "Cache is not writable"
236+
if del_fn:
237+
del cached_files[fn_posix]["fn"]
177238
with open(cache_fn, "wb") as f:
178239
pickle.dump(cached_files, f)
179240
time.sleep(1)
@@ -255,6 +316,22 @@ def __ager(cache_fn, fn):
255316
fs.clear_expired_cache()
256317
assert not fs._check_file(str(f4))
257318

319+
# check cache metadata lacking 'fn' raises RuntimeError.
320+
fs = fsspec.filesystem(
321+
"filecache",
322+
target_protocol="file",
323+
cache_storage=str(cache1),
324+
same_names=True,
325+
cache_check=1,
326+
)
327+
assert fs.cat(str(f1)) == data
328+
329+
cache_fn = os.path.join(fs.storage[-1], "cache")
330+
__ager(cache_fn, f1, del_fn=True)
331+
332+
with pytest.raises(RuntimeError, match="Cache metadata does not contain 'fn' for"):
333+
fs.clear_expired_cache()
334+
258335

259336
def test_pop():
260337
import tempfile

fsspec/tests/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def test_chained_equivalent():
308308
# since the parameters don't quite match. Also, the url understood by the two
309309
# of s are not the same (path gets munged a bit differently)
310310
assert of.fs == of2.fs
311+
assert hash(of.fs) == hash(of2.fs)
311312
assert of.open().read() == of2.open().read()
312313

313314

0 commit comments

Comments
 (0)