Skip to content

Commit c68b256

Browse files
committed
Allow delayed shapefile loading by passing no args, and add to tests
See #195
1 parent 050b62f commit c68b256

File tree

2 files changed

+26
-39
lines changed

2 files changed

+26
-39
lines changed

shapefile.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,11 +802,12 @@ def __init__(self, *args, **kwargs):
802802
self.__fieldposition_lookup = {}
803803
self.encoding = kwargs.pop('encoding', 'utf-8')
804804
self.encodingErrors = kwargs.pop('encodingErrors', 'strict')
805-
# See if a shapefile name was passed as an argument
805+
# See if a shapefile name was passed as the first argument
806806
if len(args) > 0:
807807
if is_string(args[0]):
808808
self.load(args[0])
809809
return
810+
# Otherwise, load from separate shp/shx/dbf args (must be file-like)
810811
if "shp" in kwargs.keys():
811812
if hasattr(kwargs["shp"], "read"):
812813
self.shp = kwargs["shp"]
@@ -815,6 +816,9 @@ def __init__(self, *args, **kwargs):
815816
self.shp.seek(0)
816817
except (NameError, io.UnsupportedOperation):
817818
self.shp = io.BytesIO(self.shp.read())
819+
else:
820+
raise ShapefileException('The shp arg must be file-like.')
821+
818822
if "shx" in kwargs.keys():
819823
if hasattr(kwargs["shx"], "read"):
820824
self.shx = kwargs["shx"]
@@ -823,6 +827,9 @@ def __init__(self, *args, **kwargs):
823827
self.shx.seek(0)
824828
except (NameError, io.UnsupportedOperation):
825829
self.shx = io.BytesIO(self.shx.read())
830+
else:
831+
raise ShapefileException('The shx arg must be file-like.')
832+
826833
if "dbf" in kwargs.keys():
827834
if hasattr(kwargs["dbf"], "read"):
828835
self.dbf = kwargs["dbf"]
@@ -831,10 +838,12 @@ def __init__(self, *args, **kwargs):
831838
self.dbf.seek(0)
832839
except (NameError, io.UnsupportedOperation):
833840
self.dbf = io.BytesIO(self.dbf.read())
841+
else:
842+
raise ShapefileException('The dbf arg must be file-like.')
843+
844+
# Load the files
834845
if self.shp or self.dbf:
835846
self.load()
836-
else:
837-
raise ShapefileException("Shapefile Reader requires a shapefile or file-like object.")
838847

839848
def __str__(self):
840849
"""

test_shapefile.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -291,42 +291,6 @@ def test_reader_shapefile_extension_ignored():
291291
assert not os.path.exists(filename)
292292

293293

294-
def test_reader_dbf_only():
295-
"""
296-
Assert that specifying just the
297-
dbf argument to the shapefile reader
298-
reads just the dbf file.
299-
"""
300-
with shapefile.Reader(dbf="shapefiles/blockgroups.dbf") as sf:
301-
assert len(sf) == 663
302-
record = sf.record(3)
303-
assert record[1:3] == ['060750601001', 4715]
304-
305-
306-
def test_reader_shp_shx_only():
307-
"""
308-
Assert that specifying just the
309-
shp and shx argument to the shapefile reader
310-
reads just the shp and shx file.
311-
"""
312-
with shapefile.Reader(shp="shapefiles/blockgroups.shp", shx="shapefiles/blockgroups.shx") as sf:
313-
assert len(sf) == 663
314-
shape = sf.shape(3)
315-
assert len(shape.points) is 173
316-
317-
318-
def test_reader_shx_optional():
319-
"""
320-
Assert that specifying just the
321-
shp argument to the shapefile reader
322-
reads just the shp file (shx optional).
323-
"""
324-
with shapefile.Reader(shp="shapefiles/blockgroups.shp") as sf:
325-
assert len(sf) == 663
326-
shape = sf.shape(3)
327-
assert len(shape.points) is 173
328-
329-
330294
def test_reader_filelike_dbf_only():
331295
"""
332296
Assert that specifying just the
@@ -363,6 +327,20 @@ def test_reader_filelike_shx_optional():
363327
assert len(shape.points) is 173
364328

365329

330+
def test_reader_shapefile_delayed_load():
331+
"""
332+
Assert that the filename's extension is
333+
ignored when reading a shapefile.
334+
"""
335+
with shapefile.Reader() as sf:
336+
# assert that data request raises exception, since no file has been provided yet
337+
with pytest.raises(shapefile.ShapefileException):
338+
sf.shape(0)
339+
# assert that works after loading file manually
340+
sf.load("shapefiles/blockgroups")
341+
assert len(sf) == 663
342+
343+
366344
def test_records_match_shapes():
367345
"""
368346
Assert that the number of records matches

0 commit comments

Comments
 (0)