183 changes: 105 additions & 78 deletions s3fs/core.py
@@ -115,7 +115,7 @@ def __init__(self, anon=None, key=None, secret=None, token=None,
try:
self.anon = False
self.s3 = self.connect()
self.ls('')
self.get_delegated_s3pars()
return
except:
logger.debug('Accredited connection failed, trying anonymous')
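A hedged sketch of what this hunk means for callers: credentials are now validated with a lightweight STS round trip via `get_delegated_s3pars()` rather than by listing all buckets, and a failed check still falls back to anonymous access (assuming the anonymous retry in the elided code succeeds). The credentials below are deliberately invalid and purely illustrative.

```python
from s3fs import S3FileSystem

# Hypothetical, invalid credentials: the STS check in __init__ fails,
# so the constructor falls back to an anonymous connection.
fs = S3FileSystem(key='BADKEY', secret='BADSECRET')
print(fs.anon)  # expected: True after the fallback
```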
@@ -157,13 +157,15 @@ def connect(self, refresh=False):
conf = Config(connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
signature_version=UNSIGNED)
s3 = boto3.Session(**self.kwargs).client('s3', config=conf,
self.session = boto3.Session(**self.kwargs)
s3 = self.session.client('s3', config=conf,
use_ssl=ssl)
else:
conf = Config(connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout)
s3 = boto3.Session(self.key, self.secret, self.token,
**self.kwargs).client('s3', config=conf,
self.session = boto3.Session(self.key, self.secret, self.token,
**self.kwargs)
s3 = self.session.client('s3', config=conf,
use_ssl=ssl)
self._conn[tok] = s3
return self._conn[tok]
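The boto3 `Session` is now kept on the instance, so later helpers can derive other clients (e.g. STS) from it instead of building a fresh session. A rough boto3-only sketch of what the anonymous branch above does:

```python
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Keep one Session for reuse, then build an unsigned S3 client from it.
session = boto3.Session()
s3 = session.client('s3', config=Config(signature_version=UNSIGNED))
```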
@@ -185,8 +187,7 @@ def get_delegated_s3pars(self, exp=3600):
if self.token: # already has temporary cred
return {'key': self.key, 'secret': self.secret, 'token': self.token,
'anon': False}
sts = boto3.Session(self.key, self.secret, self.token,
**self.kwargs).client('sts')
sts = self.session.client('sts')
cred = sts.get_session_token(DurationSeconds=3600)['Credentials']
return {'key': cred['AccessKeyId'], 'secret': cred['SecretAccessKey'],
'token': cred['SessionToken'], 'anon': False}
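A usage sketch (requires working AWS credentials, so illustrative only): the dictionary returned here plugs straight back into the constructor, which is how short-lived access can be handed to another process or machine.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()                  # local, long-lived credentials
pars = fs.get_delegated_s3pars()     # {'key': ..., 'secret': ..., 'token': ..., 'anon': False}
worker_fs = S3FileSystem(**pars)     # temporary credentials, valid for roughly an hour
```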
@@ -218,6 +219,47 @@ def open(self, path, mode='rb', block_size=5 * 1024 ** 2):
" and manage bytes" % (mode[0] + 'b'))
return S3File(self, path, mode, block_size=block_size)

def _lsdir(self, path, refresh=False):
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
bucket, prefix = split_path(path)
prefix = prefix + '/' if prefix else ""
if path not in self.dirs or refresh:
try:
pag = self.s3.get_paginator('list_objects')
it = pag.paginate(Bucket=bucket, Prefix=prefix, Delimiter='/')
files = []
dirs = None
for i in it:
dirs = dirs or i.get('CommonPrefixes', None)
files.extend(i.get('Contents', []))
if dirs:
files.extend([{'Key': l['Prefix'][:-1], 'Size': 0,
'StorageClass': "DIRECTORY"} for l in dirs])
files = [f for f in files if len(f['Key']) > len(prefix)]
for f in files:
f['Key'] = '/'.join([bucket, f['Key']])
except ClientError:
# path not accessible
files = []
self.dirs[path] = files
return self.dirs[path]

def _lsbuckets(self, refresh=False):
if '' not in self.dirs or refresh:
if self.anon:
# cannot list buckets if not logged in
return []
files = self.s3.list_buckets()['Buckets']
for f in files:
f['Key'] = f['Name']
f['Size'] = 0
f['StorageClass'] = 'BUCKET'
del f['Name']
self.dirs[''] = files
return self.dirs['']

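A hedged sketch of the listing shape these two new helpers produce (bucket and keys hypothetical, working credentials assumed): `_lsdir` returns only the immediate children of a prefix, with pseudo-entries of `StorageClass` `'DIRECTORY'` standing in for common prefixes, while `_lsbuckets` reports buckets with `StorageClass` `'BUCKET'`.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()
fs.touch('mybucket')                     # hypothetical bucket name
fs.touch('mybucket/data/part-0.csv')
fs.touch('mybucket/top.txt')

fs.ls('mybucket')                        # ['mybucket/top.txt', 'mybucket/data'] (order may vary)
sorted(f['StorageClass'] for f in fs.ls('mybucket', detail=True))
# ['DIRECTORY', 'STANDARD']              # 'mybucket/data' is a placeholder, not a real key
```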
def _ls(self, path, refresh=False):
""" List files in given bucket, or list of buckets.

@@ -238,49 +280,22 @@ def _ls(self, path, refresh=False):
"""
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
bucket, key = split_path(path)
if bucket not in self.dirs or refresh:
if bucket == '':
# list of buckets
if self.anon:
# cannot list buckets if not logged in
return []
files = self.s3.list_buckets()['Buckets']
for f in files:
f['Key'] = f['Name']
f['Size'] = 0
del f['Name']
else:
try:
pag = self.s3.get_paginator('list_objects')
it = pag.paginate(Bucket=bucket)
files = []
for i in it:
files.extend(i.get('Contents', []))
except ClientError:
# bucket not accessible
raise FileNotFoundError(bucket)
for f in files:
f['Key'] = "/".join([bucket, f['Key']])
self.dirs[bucket] = list(sorted(files, key=lambda x: x['Key']))
files = self.dirs[bucket]
return files
if path in ['', '/']:
return self._lsbuckets(refresh)
else:
return self._lsdir(path, refresh)

def ls(self, path, detail=False, refresh=False):
""" List single "directory" with or without details """
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
files = self._ls(path, refresh=refresh)
if path:
pattern = re.compile(path + '/[^/]*.$')
files = [f for f in files if pattern.match(f['Key']) is not None]
if not files:
try:
files = [self.info(path)]
except (OSError, IOError, ClientError):
files = []
if not files:
if split_path(path)[1]:
files = [self.info(path)]
elif path:
raise FileNotFoundError(path)
if detail:
return files
else:
@@ -289,26 +304,41 @@ def ls(self, path, detail=False, refresh=False):
def info(self, path, refresh=False):
""" Detail on the specific file pointed to by path.

NB: path has trailing '/' stripped to work as `ls` does, so key
names that genuinely end in '/' will fail.
Gets details only for a specific key; directories/buckets cannot be
used with info.
"""
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
files = self._ls(path, refresh=refresh)
files = [f for f in files if f['Key'].rstrip('/') == path]
parent = path.rsplit('/', 1)[0]
files = self._lsdir(parent, refresh=refresh)
files = [f for f in files if f['Key'] == path and f['StorageClass'] not
in ['DIRECTORY', 'BUCKET']]
if len(files) == 1:
return files[0]
else:
raise IOError("File not found: %s" % path)
try:
bucket, key = split_path(path)
out = self.s3.head_object(Bucket=bucket, Key=key)
out = {'ETag': out['ETag'], 'Key': '/'.join([bucket, key]),
'LastModified': out['LastModified'],
'Size': out['ContentLength'], 'StorageClass': "STANDARD"}
return out
except (ClientError, ParamValidationError):
raise FileNotFoundError(path)

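Continuing the hypothetical bucket from the listing sketch above, a sketch of the new lookup order: `info` first consults the cached parent listing and, only if the key is absent there (or is a directory/bucket placeholder), falls back to a direct `head_object`, so keys written by other clients are still found.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()                   # bucket and keys as in the listing sketch above
fs.info('mybucket/data/part-0.csv')   # served from the cached parent listing
# {'Key': 'mybucket/data/part-0.csv', 'Size': 0, 'StorageClass': 'STANDARD', ...}

# fs.info('mybucket/data')            # directory placeholder: head_object 404s,
#                                     # so this raises FileNotFoundError
```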
def walk(self, path, refresh=False):
""" Return all entries below path """
def _walk(self, path, refresh=False):
if path.startswith('s3://'):
path = path[len('s3://'):]
filenames = self._ls(path, refresh=refresh)
return [f['Key'] for f in filenames if f['Key'].rstrip('/'
).startswith(path.rstrip('/') + '/')]
if path in ['', '/']:
raise ValueError('Cannot walk all of S3')
filenames = self._ls(path, refresh=refresh)[:]
for f in filenames[:]:
if f['StorageClass'] == 'DIRECTORY':
filenames.extend(self._walk(f['Key'], refresh))
return [f for f in filenames if f['StorageClass'] not in
['BUCKET', 'DIRECTORY']]

def walk(self, path, refresh=False):
""" Return all real keys below path """
return [f['Key'] for f in self._walk(path, refresh)]

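A sketch of the recursive behaviour (same hypothetical bucket as above): `_walk` descends through the `'DIRECTORY'` placeholders but `walk` reports only real keys, and walking the whole of S3 is refused.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()          # same hypothetical bucket as above
fs.walk('mybucket')
# ['mybucket/top.txt', 'mybucket/data/part-0.csv']   (placeholders excluded)

# fs.walk('')                # would raise ValueError('Cannot walk all of S3')
```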
def glob(self, path):
"""
@@ -354,15 +384,12 @@ def du(self, path, total=False, deep=False):
return {p['Key']: p['Size'] for p in files}

def exists(self, path):
""" Does such a file exist? """
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
if split_path(path)[1]:
return bool(self.ls(path))
""" Does such a file/directory exist? """
bucket, key = split_path(path)
if key:
return not raises(FileNotFoundError, lambda: self.ls(path))
else:
return (path in self.ls('') or
not raises(FileNotFoundError, lambda: self.ls(path)))
return bucket in self.ls('')

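A sketch of the sharpened semantics (same hypothetical bucket): for a key or prefix, `exists` now simply asks whether `ls` succeeds, so directory-style prefixes count too, while a bare bucket name is checked against the bucket listing.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()                    # same hypothetical bucket as above
fs.exists('mybucket/data')             # True: the prefix has at least one key under it
fs.exists('mybucket/missing.txt')      # False, without raising
fs.exists('mybucket')                  # True only if the bucket appears in ls('')
```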
def cat(self, path):
""" Returns contents of file """
@@ -441,7 +468,7 @@ def merge(self, path, filelist):
part_info = {'Parts': parts}
self.s3.complete_multipart_upload(Bucket=bucket, Key=key,
UploadId=mpu['UploadId'], MultipartUpload=part_info)
self.invalidate_cache(bucket)
self.invalidate_cache(path)

def copy(self, path1, path2):
""" Copy file between locations on S3 """
Expand All @@ -452,7 +479,7 @@ def copy(self, path1, path2):
CopySource='/'.join([buc1, key1]))
except (ClientError, ParamValidationError):
raise IOError('Copy failed', (path1, path2))
self.invalidate_cache(buc2)
self.invalidate_cache(path2)

def bulk_delete(self, pathlist):
"""
@@ -471,14 +498,14 @@ def bulk_delete(self, pathlist):
bucket = buckets.pop()
if len(pathlist) > 1000:
for i in range((len(pathlist) // 1000) + 1):
print(i)
self.bulk_delete(pathlist[i*1000:(i+1)*1000])
return
delete_keys = {'Objects': [{'Key' : split_path(path)[1]} for path
in pathlist]}
try:
self.s3.delete_objects(Bucket=bucket, Delete=delete_keys)
self.invalidate_cache(bucket)
for path in pathlist:
self.invalidate_cache(path)
except ClientError:
raise IOError('Bulk delete failed')

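A sketch of how the batching and the new per-path cache invalidation interact (keys hypothetical and assumed to exist): requests are split into chunks of 1000, the S3 limit for `delete_objects`, and every deleted path has its own cached listing dropped.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()
# hypothetical keys, assumed to already exist under 'mybucket/tmp/'
keys = ['mybucket/tmp/part-%d.csv' % i for i in range(2500)]
fs.bulk_delete(keys)                   # issued as three delete_objects calls of <=1000 keys
fs.exists('mybucket/tmp/part-0.csv')   # False; the per-path cache entries were dropped
```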
@@ -506,24 +533,25 @@ def rm(self, path, recursive=False):
self.s3.delete_object(Bucket=bucket, Key=key)
except ClientError:
raise IOError('Delete key failed', (bucket, key))
self.invalidate_cache(bucket)
self.invalidate_cache(path)
else:
if not self.s3.list_objects(Bucket=bucket).get('Contents'):
try:
self.s3.delete_bucket(Bucket=bucket)
except ClientError:
raise IOError('Delete bucket failed', bucket)
self.dirs.pop(bucket, None)
self.invalidate_cache(bucket)
self.invalidate_cache('')
else:
raise IOError('Not empty', path)

def invalidate_cache(self, bucket=None):
if bucket is None:
def invalidate_cache(self, path=None):
if path is None:
self.dirs.clear()
elif bucket in self.dirs:
del self.dirs[bucket]
else:
self.dirs.pop(path, None)
parent = path.rsplit('/', 1)[0]
self.dirs.pop(parent, None)

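The cache is now keyed by path rather than by bucket; a sketch of the new invalidation scope, which drops both the entry itself and its parent listing:

```python
from s3fs import S3FileSystem

fs = S3FileSystem()                               # hypothetical bucket as above
fs.ls('mybucket/data')                            # populates fs.dirs['mybucket/data']
fs.invalidate_cache('mybucket/data/part-0.csv')   # drops the entry and its parent
'mybucket/data' in fs.dirs                        # False
fs.invalidate_cache()                             # no argument: clear everything
```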
def touch(self, path):
"""
@@ -534,11 +562,12 @@ def touch(self, path):
bucket, key = split_path(path)
if key:
self.s3.put_object(Bucket=bucket, Key=key)
self.invalidate_cache(bucket)
self.invalidate_cache(path)
else:
try:
self.s3.create_bucket(Bucket=bucket)
self.invalidate_cache('')
self.invalidate_cache(bucket)
except (ClientError, ParamValidationError):
raise IOError('Bucket create failed', path)

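A sketch of the two branches (bucket name hypothetical): a bare bucket name creates the bucket and invalidates both the bucket list and the bucket's own listing, while a full path writes a zero-byte key and invalidates that path.

```python
from s3fs import S3FileSystem

fs = S3FileSystem()
fs.touch('new-bucket')              # bare bucket name: create_bucket
fs.touch('new-bucket/empty.bin')    # bucket/key: put a zero-byte object
fs.ls('new-bucket')                 # ['new-bucket/empty.bin']
```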
@@ -662,9 +691,7 @@ def __init__(self, s3, path, mode='rb', block_size=5 * 2 ** 20):

def info(self):
""" File information about this path """
info = self.s3.s3.head_object(Bucket=self.bucket, Key=self.key)
info['Size'] = info.get('ContentLength')
return info
return self.s3.info(self.path)

def tell(self):
""" Current file location """
@@ -875,7 +902,7 @@ def close(self):
Body=self.buffer.read())
except (ClientError, ParamValidationError) as e:
raise IOError('Write failed: %s' % self.path, e)
self.s3.invalidate_cache(self.bucket)
self.s3.invalidate_cache(self.path)
self.closed = True

def readable(self):
7 changes: 7 additions & 0 deletions s3fs/tests/test_mapping.py
@@ -1,3 +1,4 @@
import pytest
from s3fs.tests.test_s3fs import s3, test_bucket_name
from s3fs import S3Map, S3FileSystem

@@ -19,6 +20,12 @@ def test_default_s3filesystem(s3):
assert d.s3 is s3


def test_errors(s3):
d = S3Map(root, s3)
with pytest.raises(KeyError):
d['nonexistent']


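For context, a hedged sketch of the dict-like interface these tests exercise (bucket and prefix hypothetical; the positional arguments mirror the tests above):

```python
from s3fs import S3FileSystem, S3Map

fs = S3FileSystem()
m = S3Map('mybucket/mapped', fs)    # keys live under this hypothetical prefix
m['x'] = b'123'
m['x']                              # b'123'
# m['nonexistent']                  # would raise KeyError, as the new test asserts
```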
def test_with_data(s3):
d = S3Map(root, s3)
d['x'] = b'123'
22 changes: 17 additions & 5 deletions s3fs/tests/test_s3fs.py
@@ -107,6 +107,16 @@ def test_multiple_objects(s3):
assert s3.ls('test') == s32.ls('test')


def test_info(s3):
s3.touch(a)
s3.touch(b)
assert s3.info(a) == s3.ls(a, detail=True)[0]
parent = a.rsplit('/', 1)[0]
s3.dirs[parent].pop(0) # disappear our file!
assert a not in s3.ls(parent)
assert s3.info(a) # now uses head_object


@pytest.mark.xfail()
def test_delegate(s3):
out = s3.get_delegated_s3pars()
@@ -126,8 +136,6 @@ def test_ls(s3):
s3.ls('nonexistent')
fn = test_bucket_name+'/test/accounts.1.json'
assert fn in s3.ls(test_bucket_name+'/test')
# assert fn in s3.ls(test_bucket_name)
# assert [fn] == s3.ls(fn)


def test_pickle(s3):
@@ -137,7 +145,7 @@ def test_pickle(s3):


def test_ls_touch(s3):
assert not s3.ls(test_bucket_name+'/tmp/test')
assert not s3.exists(test_bucket_name+'/tmp/test')
s3.touch(a)
s3.touch(b)
L = s3.ls(test_bucket_name+'/tmp/test', True)
@@ -161,8 +169,7 @@ def test_rm(s3):

#whole bucket
s3.rm(test_bucket_name, recursive=True)
with pytest.raises((IOError, OSError)):
s3.exists(test_bucket_name+'/2014-01-01.csv')
assert not s3.exists(test_bucket_name+'/2014-01-01.csv')
assert not s3.exists(test_bucket_name)


@@ -357,6 +364,11 @@ def test_errors(s3):
with pytest.raises((IOError, OSError)):
s3.mkdir('/')

with pytest.raises(ValueError):
s3.walk('')

with pytest.raises(ValueError):
s3.walk('s3://')

def test_read_small(s3):
fn = test_bucket_name+'/2014-01-01.csv'