22import warnings
33from collections import OrderedDict
44from distutils .version import LooseVersion
5-
65import numpy as np
76
87from .. import DataArray
2423class RasterioArrayWrapper (BackendArray ):
2524 """A wrapper around rasterio dataset objects"""
2625
27- def __init__ (self , manager , lock ):
26+ def __init__ (self , manager , lock , vrt_params = None ):
27+ from rasterio .vrt import WarpedVRT
2828 self .manager = manager
2929 self .lock = lock
3030
3131 # cannot save riods as an attribute: this would break pickleability
3232 riods = manager .acquire ()
33-
33+ if vrt_params is not None :
34+ riods = WarpedVRT (riods , ** vrt_params )
35+ self .vrt_params = vrt_params
3436 self ._shape = (riods .count , riods .height , riods .width )
3537
3638 dtypes = riods .dtypes
@@ -104,6 +106,7 @@ def _get_indexer(self, key):
104106 return band_key , tuple (window ), tuple (squeeze_axis ), tuple (np_inds )
105107
106108 def _getitem (self , key ):
109+ from rasterio .vrt import WarpedVRT
107110 band_key , window , squeeze_axis , np_inds = self ._get_indexer (key )
108111
109112 if not band_key or any (start == stop for (start , stop ) in window ):
@@ -114,6 +117,8 @@ def _getitem(self, key):
114117 else :
115118 with self .lock :
116119 riods = self .manager .acquire (needs_lock = False )
120+ if self .vrt_params is not None :
121+ riods = WarpedVRT (riods , ** self .vrt_params )
117122 out = riods .read (band_key , window = window )
118123
119124 if squeeze_axis :
@@ -178,8 +183,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
178183
179184 Parameters
180185 ----------
181- filename : str
182- Path to the file to open.
186+ filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
187+ Path to the file to open. Or already open rasterio dataset.
183188 parse_coordinates : bool, optional
184189 Whether to parse the x and y coordinates out of the file's
185190 ``transform`` attribute or not. The default is to automatically
@@ -206,14 +211,28 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
206211 data : DataArray
207212 The newly created DataArray.
208213 """
209-
210214 import rasterio
215+ from rasterio .vrt import WarpedVRT
216+ vrt_params = None
217+ if isinstance (filename , rasterio .io .DatasetReader ):
218+ filename = filename .name
219+ elif isinstance (filename , rasterio .vrt .WarpedVRT ):
220+ vrt = filename
221+ filename = vrt .src_dataset .name
222+ vrt_params = dict (crs = vrt .crs .to_string (),
223+ resampling = vrt .resampling ,
224+ src_nodata = vrt .src_nodata ,
225+ dst_nodata = vrt .dst_nodata ,
226+ tolerance = vrt .tolerance ,
227+ warp_extras = vrt .warp_extras )
211228
212229 if lock is None :
213230 lock = RASTERIO_LOCK
214231
215232 manager = CachingFileManager (rasterio .open , filename , lock = lock , mode = 'r' )
216233 riods = manager .acquire ()
234+ if vrt_params is not None :
235+ riods = WarpedVRT (riods , ** vrt_params )
217236
218237 if cache is None :
219238 cache = chunks is None
@@ -287,14 +306,14 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
287306 for k , v in meta .items ():
288307 # Add values as coordinates if they match the band count,
289308 # as attributes otherwise
290- if (isinstance (v , (list , np .ndarray )) and
291- len (v ) == riods .count ):
309+ if (isinstance (v , (list , np .ndarray ))
310+ and len (v ) == riods .count ):
292311 coords [k ] = ('band' , np .asarray (v ))
293312 else :
294313 attrs [k ] = v
295314
296315 data = indexing .LazilyOuterIndexedArray (
297- RasterioArrayWrapper (manager , lock ))
316+ RasterioArrayWrapper (manager , lock , vrt_params ))
298317
299318 # this lets you write arrays loaded with rasterio
300319 data = indexing .CopyOnWriteArray (data )
0 commit comments