Skip to content

Commit a6100a0

Browse files
committed
Rewrite, and some tests
1 parent 6f173bc commit a6100a0

File tree

4 files changed

+175
-113
lines changed

4 files changed

+175
-113
lines changed

conda-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pyke
1111
udunits2
1212
cf_units
1313
dask
14+
distributed
1415

1516
# Iris build dependencies
1617
setuptools

lib/iris/options.py

Lines changed: 133 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
"""
2121
from __future__ import (absolute_import, division, print_function)
2222
from six.moves import (filter, input, map, range, zip) # noqa
23+
import six
2324

25+
import contextlib
2426
from multiprocessing import cpu_count
2527
from multiprocessing.pool import ThreadPool
2628
import re
@@ -31,12 +33,45 @@
3133
import distributed
3234

3335

34-
class Parallel(object):
36+
class Option(object):
37+
"""
38+
An abstract superclass to enforce certain key behaviours for all `Option`
39+
classes.
40+
41+
"""
42+
@property
43+
def _defaults_dict(self):
44+
raise NotImplementedError
45+
46+
def __setattr__(self, name, value):
47+
if name not in self.__dict__:
48+
# Can't add new names.
49+
msg = "'Option' object has no attribute {!r}".format(name)
50+
raise AttributeError(msg)
51+
if value is None:
52+
# Set an explicitly unset value to the default value for the name.
53+
value = self._defaults_dict[name]['default']
54+
if self._defaults_dict[name]['options'] is not None:
55+
# Replace a bad value with the default if there is a defined set of
56+
# specified good values.
57+
if value not in self._defaults_dict[name]['options']:
58+
good_value = self._defaults_dict[name]['default']
59+
wmsg = ('Attempting to set bad value {!r} for attribute {!r}. '
60+
'Defaulting to {!r}.')
61+
warnings.warn(wmsg.format(value, name, good_value))
62+
value = good_value
63+
self.__dict__[name] = value
64+
65+
def context(self):
66+
raise NotImplementedError
67+
68+
69+
class Parallel(Option):
3570
"""
3671
Control dask parallel processing options for Iris.
3772
3873
"""
39-
def __init__(self, scheduler='threaded', num_workers=1):
74+
def __init__(self, scheduler=None, num_workers=None):
4075
"""
4176
Set up options for dask parallel processing.
4277
@@ -89,96 +124,115 @@ def __init__(self, scheduler='threaded', num_workers=1):
89124
* Specify that we want to load a cube with dask parallel processing
90125
using multiprocessing with six worker processes::
91126
92-
>>> iris.options.parallel(scheduler='multiprocessing', num_workers=6)
93-
>>> iris.load('my_dataset.nc')
127+
iris.options.parallel(scheduler='multiprocessing', num_workers=6)
128+
iris.load('my_dataset.nc')
94129
95130
* Specify, with a context manager, that we want to load a cube with
96131
dask parallel processing using four worker threads::
97132
98-
>>> with iris.options.parallel(scheduler='threaded', num_workers=4):
99-
... iris.load('my_dataset.nc')
133+
with iris.options.parallel(scheduler='threaded', num_workers=4):
134+
iris.load('my_dataset.nc')
100135
101136
* Run dask parallel processing using a distributed scheduler that has
102137
been set up at the IP address and port at ``192.168.0.219:8786``::
103138
104-
>>> iris.options.parallel(scheduler='192.168.0.219:8786')
139+
iris.options.parallel(scheduler='192.168.0.219:8786')
105140
106141
"""
107-
# Set some defaults first of all.
108-
self._default_scheduler = 'threaded'
109-
self._default_num_workers = 1
142+
# Set `__dict__` keys first.
143+
self.__dict__['_scheduler'] = scheduler
144+
self.__dict__['scheduler'] = None
145+
self.__dict__['num_workers'] = None
146+
self.__dict__['dask_scheduler'] = None
110147

111-
self.scheduler = scheduler
112-
self.num_workers = num_workers
113-
114-
self._dask_scheduler = None
148+
# Set `__dict__` values for each kwarg.
149+
setattr(self, 'scheduler', scheduler)
150+
setattr(self, 'num_workers', num_workers)
151+
setattr(self, 'dask_scheduler', self.get('scheduler'))
115152

116153
# Activate the specified dask options.
117154
self._set_dask_options()
118155

156+
def __setattr__(self, name, value):
157+
if value is None:
158+
value = self._defaults_dict[name]['default']
159+
attr_setter = getattr(self, 'set_{}'.format(name))
160+
value = attr_setter(value)
161+
super(Parallel, self).__setattr__(name, value)
162+
119163
@property
120-
def scheduler(self):
121-
return self._scheduler
164+
def _defaults_dict(self):
165+
"""
166+
Define the default value and available options for each settable
167+
`kwarg` of this `Option`.
168+
169+
Note: `'options'` can be set to `None` if it is not reasonable to
170+
specify all possible options. For example, this may be reasonable if
171+
the `'options'` were a range of numbers.
122172
123-
@scheduler.setter
124-
def scheduler(self, value):
173+
"""
174+
return {'_scheduler': {'default': None, 'options': None},
175+
'scheduler': {'default': 'threaded',
176+
'options': ['threaded',
177+
'multiprocessing',
178+
'async',
179+
'distributed']},
180+
'num_workers': {'default': 1, 'options': None},
181+
'dask_scheduler': {'default': None, 'options': None},
182+
}
183+
184+
def set__scheduler(self, value):
185+
return value
186+
187+
def set_scheduler(self, value):
188+
default = self._defaults_dict['scheduler']['default']
125189
if value is None:
126-
value = self._default_scheduler
127-
if value == 'threaded':
128-
self._scheduler = value
129-
self.dask_scheduler = dask.threaded.get
130-
elif value == 'multiprocessing':
131-
self._scheduler = value
132-
self.dask_scheduler = dask.multiprocessing.get
133-
elif value == 'async':
134-
self._scheduler = value
135-
self.dask_scheduler = dask.async.get_sync
190+
value = default
136191
elif re.match(r'^(\d{1,3}\.){3}\d{1,3}:\d{1,5}$', value):
137-
self._scheduler = 'distributed'
138-
self.dask_scheduler = value
139-
else:
192+
value = 'distributed'
193+
elif value not in self._defaults_dict['scheduler']['options']:
140194
# Invalid value for `scheduler`.
141195
wmsg = 'Invalid value for scheduler: {!r}. Defaulting to {}.'
142-
warnings.warn(wmsg.format(value, self._default_scheduler))
143-
self.scheduler = self._default_scheduler
144-
145-
@property
146-
def num_workers(self):
147-
return self._num_workers
148-
149-
@num_workers.setter
150-
def num_workers(self, value):
151-
if self.scheduler == 'async' and value != self._default_num_workers:
196+
warnings.warn(wmsg.format(value, default))
197+
self.set_scheduler(default)
198+
return value
199+
200+
def set_num_workers(self, value):
201+
default = self._defaults_dict['num_workers']['default']
202+
scheduler = self.get('scheduler')
203+
if scheduler == 'async' and value != default:
152204
wmsg = 'Cannot set `num_workers` for the serial scheduler {!r}.'
153-
warnings.warn(wmsg.format(self.scheduler))
205+
warnings.warn(wmsg.format(scheduler))
154206
value = None
155-
elif (self.scheduler == 'distributed' and
156-
value != self._default_num_workers):
207+
elif scheduler == 'distributed' and value != default:
157208
wmsg = ('Attempting to set `num_workers` with the {!r} scheduler '
158209
'requested. Please instead specify number of workers when '
159210
'setting up the distributed scheduler. See '
160211
'https://distributed.readthedocs.io/en/latest/index.html '
161212
'for more details.')
162-
warnings.warn(wmsg.format(self.scheduler))
213+
warnings.warn(wmsg.format(scheduler))
163214
value = None
164215
else:
165216
if value is None:
166-
value = self._default_num_workers
217+
value = default
167218
if value >= cpu_count():
168219
# Limit maximum CPUs used to 1 fewer than all available CPUs.
169220
wmsg = ('Requested more CPUs ({}) than total available ({}). '
170221
'Limiting number of used CPUs to {}.')
171222
warnings.warn(wmsg.format(value, cpu_count(), cpu_count()-1))
172223
value = cpu_count() - 1
173-
self._num_workers = value
174-
175-
@property
176-
def dask_scheduler(self):
177-
return self._dask_scheduler
178-
179-
@dask_scheduler.setter
180-
def dask_scheduler(self, value):
181-
self._dask_scheduler = value
224+
return value
225+
226+
def set_dask_scheduler(self, scheduler):
227+
if scheduler == 'threaded':
228+
value = dask.threaded.get
229+
elif scheduler == 'multiprocessing':
230+
value = dask.multiprocessing.get
231+
elif scheduler == 'async':
232+
value = dask.async.get_sync
233+
elif scheduler == 'distributed':
234+
value = self.get('_scheduler')
235+
return value
182236

183237
def _set_dask_options(self):
184238
"""
@@ -187,25 +241,37 @@ def _set_dask_options(self):
187241
context manager.
188242
189243
"""
190-
get = self.dask_scheduler
244+
scheduler = self.get('scheduler')
245+
num_workers = self.get('num_workers')
246+
get = self.get('dask_scheduler')
191247
pool = None
192-
if self.scheduler in ['threaded', 'multiprocessing']:
193-
pool = ThreadPool(self.num_workers)
194-
if self.scheduler == 'distributed':
195-
get = distributed.Client(self.dask_scheduler).get
248+
249+
if scheduler in ['threaded', 'multiprocessing']:
250+
pool = ThreadPool(num_workers)
251+
if scheduler == 'distributed':
252+
get = distributed.Client(get).get
196253

197254
dask.set_options(get=get, pool=pool)
198255

199256
def get(self, item):
200257
return getattr(self, item)
201258

202-
def __enter__(self):
203-
return
204-
205-
def __exit__(self, exception_type, exception_value, exception_traceback):
206-
self.num_workers = self._default_num_workers
207-
self.scheduler = self._default_scheduler
259+
@contextlib.contextmanager
260+
def context(self, **kwargs):
261+
# Snapshot the starting state for restoration at the end of the
262+
# contextmanager block.
263+
starting_state = self.__dict__.copy()
264+
# Update the state to reflect the requested changes.
265+
for name, value in six.iteritems(kwargs):
266+
setattr(self, name, value)
208267
self._set_dask_options()
268+
try:
269+
yield
270+
finally:
271+
# Return the state to the starting state.
272+
self.__dict__.clear()
273+
self.__dict__.update(starting_state)
274+
self._set_dask_options()
209275

210276

211-
parallel = Parallel
277+
parallel = Parallel()

0 commit comments

Comments
 (0)