diff --git a/.gitignore b/.gitignore index b0aecc908..77eb0e8f4 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ __pycache__ .coverage *coverage.xml .coverage.* +.idea diff --git a/ipyparallel/apps/ipcontrollerapp.py b/ipyparallel/apps/ipcontrollerapp.py index 2bab5cb4f..bfe6506b4 100755 --- a/ipyparallel/apps/ipcontrollerapp.py +++ b/ipyparallel/apps/ipcontrollerapp.py @@ -33,13 +33,14 @@ from ipython_genutils.importstring import import_item from traitlets import Unicode, Bool, List, Dict, TraitError, observe -from jupyter_client.session import ( - Session, session_aliases, session_flags, -) +from jupyter_client.session import Session, session_aliases, session_flags +from ipyparallel.controller.broadcast_scheduler import launch_broadcast_scheduler, \ + BroadcastScheduler from ipyparallel.controller.heartmonitor import HeartMonitor from ipyparallel.controller.hub import HubFactory -from ipyparallel.controller.scheduler import TaskScheduler,launch_scheduler +from ipyparallel.controller.scheduler import launch_scheduler +from ipyparallel.controller.task_scheduler import TaskScheduler from ipyparallel.controller.dictdb import DictDB from ipyparallel.util import disambiguate_url @@ -62,10 +63,9 @@ real_dbs.append(MongoDB) - -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Module level variables -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- _description = """Start the IPython controller for parallel computing. @@ -84,101 +84,141 @@ """ -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # The main application -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- flags = {} flags.update(base_flags) -flags.update({ - 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}}, - 'Use threads instead of processes for the schedulers'), - 'sqlitedb' : ({'HubFactory' : {'db_class' : 'ipyparallel.controller.sqlitedb.SQLiteDB'}}, - 'use the SQLiteDB backend'), - 'mongodb' : ({'HubFactory' : {'db_class' : 'ipyparallel.controller.mongodb.MongoDB'}}, - 'use the MongoDB backend'), - 'dictdb' : ({'HubFactory' : {'db_class' : 'ipyparallel.controller.dictdb.DictDB'}}, - 'use the in-memory DictDB backend'), - 'nodb' : ({'HubFactory' : {'db_class' : 'ipyparallel.controller.dictdb.NoDB'}}, - """use dummy DB backend, which doesn't store any information. +flags.update( + { + 'usethreads': ( + {'IPControllerApp': {'use_threads': True}}, + 'Use threads instead of processes for the schedulers', + ), + 'sqlitedb': ( + {'HubFactory': {'db_class': 'ipyparallel.controller.sqlitedb.SQLiteDB'}}, + 'use the SQLiteDB backend', + ), + 'mongodb': ( + {'HubFactory': {'db_class': 'ipyparallel.controller.mongodb.MongoDB'}}, + 'use the MongoDB backend', + ), + 'dictdb': ( + {'HubFactory': {'db_class': 'ipyparallel.controller.dictdb.DictDB'}}, + 'use the in-memory DictDB backend', + ), + 'nodb': ( + {'HubFactory': {'db_class': 'ipyparallel.controller.dictdb.NoDB'}}, + """use dummy DB backend, which doesn't store any information. This is the default as of IPython 0.13. To enable delayed or repeated retrieval of results from the Hub, select one of the true db backends. 
- """), - 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}}, - 'reuse existing json connection files'), - 'restore' : ({'IPControllerApp' : {'restore_engines' : True, 'reuse_files' : True}}, - 'Attempt to restore engines from a JSON file. ' - 'For use when resuming a crashed controller'), -}) + """, + ), + 'reuse': ( + {'IPControllerApp': {'reuse_files': True}}, + 'reuse existing json connection files', + ), + 'restore': ( + {'IPControllerApp': {'restore_engines': True, 'reuse_files': True}}, + 'Attempt to restore engines from a JSON file. ' + 'For use when resuming a crashed controller', + ), + } +) flags.update(session_flags) aliases = dict( - ssh = 'IPControllerApp.ssh_server', - enginessh = 'IPControllerApp.engine_ssh_server', - location = 'IPControllerApp.location', - - url = 'HubFactory.url', - ip = 'HubFactory.ip', - transport = 'HubFactory.transport', - port = 'HubFactory.regport', - - ping = 'HeartMonitor.period', - - scheme = 'TaskScheduler.scheme_name', - hwm = 'TaskScheduler.hwm', + ssh='IPControllerApp.ssh_server', + enginessh='IPControllerApp.engine_ssh_server', + location='IPControllerApp.location', + url='HubFactory.url', + ip='HubFactory.ip', + transport='HubFactory.transport', + port='HubFactory.regport', + ping='HeartMonitor.period', + scheme='TaskScheduler.scheme_name', + hwm='TaskScheduler.hwm', ) aliases.update(base_aliases) aliases.update(session_aliases) + class IPControllerApp(BaseParallelApplication): name = u'ipcontroller' description = _description examples = _examples - classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, DictDB] + real_dbs - + classes = [ + ProfileDir, + Session, + HubFactory, + TaskScheduler, + HeartMonitor, + DictDB, + ] + real_dbs + # change default to True - auto_create = Bool(True, config=True, - help="""Whether to create profile dir if it doesn't exist.""") - - reuse_files = Bool(False, config=True, + auto_create = Bool( + True, config=True, help="""Whether to create profile dir if it doesn't exist.""" + ) + + reuse_files = Bool( + False, + config=True, help="""Whether to reuse existing json connection files. If False, connection files will be removed on a clean exit. - """ + """, ) - restore_engines = Bool(False, config=True, + restore_engines = Bool( + False, + config=True, help="""Reload engine state from JSON file - """ + """, ) - ssh_server = Unicode(u'', config=True, + ssh_server = Unicode( + u'', + config=True, help="""ssh url for clients to use when connecting to the Controller processes. It should be of the form: [user@]server[:port]. The Controller's listening addresses must be accessible from the ssh server""", ) - engine_ssh_server = Unicode(u'', config=True, + engine_ssh_server = Unicode( + u'', + config=True, help="""ssh url for engines to use when connecting to the Controller processes. It should be of the form: [user@]server[:port]. The Controller's listening addresses must be accessible from the ssh server""", ) - location = Unicode(socket.gethostname(), config=True, + location = Unicode( + socket.gethostname(), + config=True, help="""The external IP or domain name of the Controller, used for disambiguating engine and client connections.""", ) - import_statements = List([], config=True, - help="import statements to be run at startup. Necessary in some environments" + import_statements = List( + [], + config=True, + help="import statements to be run at startup. 
Necessary in some environments", ) - use_threads = Bool(False, config=True, - help='Use threads instead of processes for the schedulers', + use_threads = Bool( + False, config=True, help='Use threads instead of processes for the schedulers' ) - engine_json_file = Unicode('ipcontroller-engine.json', config=True, - help="JSON filename where engine connection info will be stored.") - client_json_file = Unicode('ipcontroller-client.json', config=True, - help="JSON filename where client connection info will be stored.") + engine_json_file = Unicode( + 'ipcontroller-engine.json', + config=True, + help="JSON filename where engine connection info will be stored.", + ) + client_json_file = Unicode( + 'ipcontroller-client.json', + config=True, + help="JSON filename where client connection info will be stored.", + ) @observe('cluster_id') def _cluster_id_changed(self, change): @@ -186,7 +226,6 @@ def _cluster_id_changed(self, change): self.engine_json_file = "%s-engine.json" % self.name self.client_json_file = "%s-client.json" % self.name - # internal children = List() mq_class = Unicode('zmq.devices.ProcessMonitoredQueue') @@ -196,16 +235,16 @@ def _use_threads_changed(self, change): self.mq_class = 'zmq.devices.{}MonitoredQueue'.format( 'Thread' if change['new'] else 'Process' ) - - write_connection_files = Bool(True, + + write_connection_files = Bool( + True, help="""Whether to write connection files to disk. True in all cases other than runs with `reuse_files=True` *after the first* - """ + """, ) aliases = Dict(aliases) flags = Dict(flags) - def save_connection_dict(self, fname, cdict): """save a connection dict to json file.""" @@ -213,49 +252,51 @@ def save_connection_dict(self, fname, cdict): self.log.info("writing connection info to %s", fname) with open(fname, 'w') as f: f.write(json.dumps(cdict, indent=2)) - os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR) - + os.chmod(fname, stat.S_IRUSR | stat.S_IWUSR) + def load_config_from_json(self): """load config from existing json connector files.""" c = self.config self.log.debug("loading config from JSON") - + # load engine config - + fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file) self.log.info("loading connection info from %s", fname) with open(fname) as f: ecfg = json.loads(f.read()) - + # json gives unicode, Session.key wants bytes c.Session.key = ecfg['key'].encode('ascii') - - xport,ip = ecfg['interface'].split('://') - + + xport, ip = ecfg['interface'].split('://') + c.HubFactory.engine_ip = ip c.HubFactory.engine_transport = xport - + self.location = ecfg['location'] if not self.engine_ssh_server: self.engine_ssh_server = ecfg['ssh'] - + # load client config - + fname = os.path.join(self.profile_dir.security_dir, self.client_json_file) self.log.info("loading connection info from %s", fname) with open(fname) as f: ccfg = json.loads(f.read()) - + for key in ('key', 'registration', 'pack', 'unpack', 'signature_scheme'): - assert ccfg[key] == ecfg[key], "mismatch between engine and client info: %r" % key - + assert ccfg[key] == ecfg[key], ( + "mismatch between engine and client info: %r" % key + ) + xport, ip = ccfg['interface'].split('://') - + c.HubFactory.client_transport = xport c.HubFactory.client_ip = ip if not self.ssh_server: self.ssh_server = ccfg['ssh'] - + # load port config: c.HubFactory.regport = ecfg['registration'] c.HubFactory.hb = (ecfg['hb_ping'], ecfg['hb_pong']) @@ -264,7 +305,7 @@ def load_config_from_json(self): c.HubFactory.task = (ccfg['task'], ecfg['task']) c.HubFactory.iopub = (ccfg['iopub'], 
ecfg['iopub']) c.HubFactory.notifier_port = ccfg['notification'] - + def cleanup_connection_files(self): if self.reuse_files: self.log.debug("leaving JSON connection files for reuse") @@ -278,29 +319,32 @@ def cleanup_connection_files(self): self.log.error("Failed to cleanup connection file: %s", e) else: self.log.debug(u"removed %s", f) - + def load_secondary_config(self): """secondary config, loading from JSON and setting defaults""" if self.reuse_files: try: self.load_config_from_json() - except (AssertionError,IOError) as e: + except (AssertionError, IOError) as e: self.log.error("Could not load config from JSON: %s" % e) else: # successfully loaded config from JSON, and reuse=True # no need to wite back the same file self.write_connection_files = False - + self.log.debug("Config changed") self.log.debug(repr(self.config)) - + def init_hub(self): c = self.config - + self.do_import_statements() - + try: - self.factory = HubFactory(config=c, log=self.log) + self.factory = HubFactory( + config=c, + log=self.log, + ) # self.start_logging() self.factory.init_hub() except TraitError: @@ -308,64 +352,77 @@ def init_hub(self): except Exception: self.log.error("Couldn't construct the Controller", exc_info=True) self.exit(1) - + if self.write_connection_files: # save to new json config files f = self.factory base = { - 'key' : f.session.key.decode('ascii'), - 'location' : self.location, - 'pack' : f.session.packer, - 'unpack' : f.session.unpacker, - 'signature_scheme' : f.session.signature_scheme, + 'key': f.session.key.decode('ascii'), + 'location': self.location, + 'pack': f.session.packer, + 'unpack': f.session.unpacker, + 'signature_scheme': f.session.signature_scheme, } - - cdict = {'ssh' : self.ssh_server} + + cdict = {'ssh': self.ssh_server} cdict.update(f.client_info) cdict.update(base) self.save_connection_dict(self.client_json_file, cdict) - - edict = {'ssh' : self.engine_ssh_server} + + edict = {'ssh': self.engine_ssh_server} edict.update(f.engine_info) edict.update(base) self.save_connection_dict(self.engine_json_file, edict) fname = "engines%s.json" % self.cluster_id - self.factory.hub.engine_state_file = os.path.join(self.profile_dir.log_dir, fname) + self.factory.hub.engine_state_file = os.path.join( + self.profile_dir.log_dir, fname + ) if self.restore_engines: self.factory.hub._load_engine_state() # load key into config so other sessions in this process (TaskScheduler) # have the same value self.config.Session.key = self.factory.session.key + def launch_python_scheduler(self, scheduler_args, children): + if 'Process' in self.mq_class: + # run the Python scheduler in a Process + q = Process(target=launch_scheduler, kwargs=scheduler_args) + q.daemon = True + children.append(q) + else: + # single-threaded Controller + scheduler_args['in_thread'] = True + launch_scheduler(**scheduler_args) + def init_schedulers(self): children = self.children mq = import_item(str(self.mq_class)) - + f = self.factory ident = f.session.bsession # disambiguate url, in case of * monitor_url = disambiguate_url(f.monitor_url) # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url # IOPub relay (in a Process) - q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub') + q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A', b'iopub') q.bind_in(f.client_url('iopub')) q.setsockopt_in(zmq.IDENTITY, ident + b"_iopub") q.bind_out(f.engine_url('iopub')) q.setsockopt_out(zmq.SUBSCRIBE, b'') q.connect_mon(monitor_url) - q.daemon=True + q.daemon = True children.append(q) # Multiplexer Queue (in a Process) q = 
mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out') - + q.bind_in(f.client_url('mux')) q.setsockopt_in(zmq.IDENTITY, b'mux_in') q.bind_out(f.engine_url('mux')) q.setsockopt_out(zmq.IDENTITY, b'mux_out') q.connect_mon(monitor_url) - q.daemon=True + q.daemon = True children.append(q) # Control Queue (in a Process) @@ -375,7 +432,7 @@ def init_schedulers(self): q.bind_out(f.engine_url('control')) q.setsockopt_out(zmq.IDENTITY, b'control_out') q.connect_mon(monitor_url) - q.daemon=True + q.daemon = True children.append(q) if 'TaskScheduler.scheme_name' in self.config: scheme = self.config.TaskScheduler.scheme_name @@ -391,35 +448,28 @@ def init_schedulers(self): q.bind_out(f.engine_url('task')) q.setsockopt_out(zmq.IDENTITY, b'task_out') q.connect_mon(monitor_url) - q.daemon=True + q.daemon = True children.append(q) elif scheme == 'none': self.log.warn("task::using no Task scheduler") else: - self.log.info("task::using Python %s Task scheduler"%scheme) - sargs = (f.client_url('task'), f.engine_url('task'), - monitor_url, disambiguate_url(f.client_url('notification')), - disambiguate_url(f.client_url('registration')), + self.log.info("task::using Python %s Task scheduler" % scheme) + self.launch_python_scheduler( + self.get_python_scheduler_args('task', f, TaskScheduler, monitor_url), + children, ) - kwargs = dict(logname='scheduler', loglevel=self.log_level, - log_url = self.log_url, config=dict(self.config)) - if 'Process' in self.mq_class: - # run the Python scheduler in a Process - q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs) - q.daemon=True - children.append(q) - else: - # single-threaded Controller - kwargs['in_thread'] = True - launch_scheduler(*sargs, **kwargs) - + + self.launch_broadcast_schedulers( + f, monitor_url, children + ) + # set unlimited HWM for all relay devices if hasattr(zmq, 'SNDHWM'): q = children[0] q.setsockopt_in(zmq.RCVHWM, 0) q.setsockopt_out(zmq.SNDHWM, 0) - + for q in children[1:]: if not hasattr(q, 'setsockopt_in'): continue @@ -428,7 +478,6 @@ def init_schedulers(self): q.setsockopt_out(zmq.SNDHWM, 0) q.setsockopt_out(zmq.RCVHWM, 0) q.setsockopt_mon(zmq.SNDHWM, 0) - def terminate_children(self): child_procs = [] @@ -445,12 +494,12 @@ def terminate_children(self): except OSError: # already dead pass - + def handle_signal(self, sig, frame): self.log.critical("Received signal %i, shutting down", sig) self.terminate_children() self.loop.stop() - + def init_signal(self): for sig in (SIGINT, SIGABRT, SIGTERM): signal(sig, self.handle_signal) @@ -466,7 +515,7 @@ def do_import_statements(self): def forward_logging(self): if self.log_url: - self.log.info("Forwarding logging to %s"%self.log_url) + self.log.info("Forwarding logging to %s" % self.log_url) context = zmq.Context.instance() lsock = context.socket(zmq.PUB) lsock.connect(self.log_url) @@ -474,7 +523,7 @@ def forward_logging(self): handler.root_topic = 'controller' handler.setLevel(self.log_level) self.log.addHandler(handler) - + @catch_config_error def initialize(self, argv=None): super(IPControllerApp, self).initialize(argv) @@ -482,7 +531,7 @@ def initialize(self, argv=None): self.load_secondary_config() self.init_hub() self.init_schedulers() - + def start(self): # Start the subprocesses: self.factory.start() @@ -500,7 +549,82 @@ def start(self): self.log.critical("Interrupted, Exiting...\n") finally: self.cleanup_connection_files() - + + def get_python_scheduler_args( + self, scheduler_name, factory, scheduler_class, monitor_url, identity=None + ): + return { + 'scheduler_class': 
scheduler_class, + 'in_addr': factory.client_url(scheduler_name), + 'out_addr': factory.engine_url(scheduler_name), + 'mon_addr': monitor_url, + 'not_addr': disambiguate_url(factory.client_url('notification')), + 'reg_addr': disambiguate_url(factory.client_url('registration')), + 'identity': identity if identity else bytes(scheduler_name, 'utf8'), + 'logname': 'scheduler', + 'loglevel': self.log_level, + 'log_url': self.log_url, + 'config': dict(self.config), + } + + def launch_broadcast_schedulers( + self, factory, monitor_url, children + ): + def launch_in_thread_or_process(scheduler_args): + + if 'Process' in self.mq_class: + # run the Python scheduler in a Process + q = Process( + target=launch_broadcast_scheduler, kwargs=scheduler_args + ) + q.daemon = True + children.append(q) + else: + # single-threaded Controller + scheduler_args['in_thread'] = True + launch_broadcast_scheduler(**scheduler_args) + + def recursively_start_schedulers(identity, depth): + outgoing_id1 = identity * 2 + 1 + outgoing_id2 = outgoing_id1 + 1 + is_leaf = depth == self.factory.broadcast_scheduler_depth + + scheduler_args = dict( + in_addr=factory.client_url(BroadcastScheduler.port_name, identity), + mon_addr=monitor_url, + not_addr=disambiguate_url(factory.client_url('notification')), + reg_addr=disambiguate_url(factory.client_url('registration')), + identity=identity, + config=dict(self.config), + loglevel=self.log_level, + log_url=self.log_url, + outgoing_ids=[outgoing_id1, outgoing_id2], + depth=depth, + is_leaf=is_leaf, + ) + if is_leaf: + scheduler_args.update( + out_addrs=[ + factory.engine_url( + BroadcastScheduler.port_name, + identity - factory.number_of_non_leaf_schedulers, + ) + ], + ) + else: + scheduler_args.update( + out_addrs=[ + factory.client_url(BroadcastScheduler.port_name, outgoing_id1), + factory.client_url(BroadcastScheduler.port_name, outgoing_id2), + ] + ) + launch_in_thread_or_process(scheduler_args) + if not is_leaf: + recursively_start_schedulers(outgoing_id1, depth + 1) + recursively_start_schedulers(outgoing_id2, depth + 1) + + recursively_start_schedulers(0, 0) + def launch_new_instance(*args, **kwargs): """Create and run the IPython controller""" @@ -508,10 +632,11 @@ def launch_new_instance(*args, **kwargs): # make sure we don't get called from a multiprocessing subprocess # this can result in infinite Controllers being started on Windows # which doesn't have a proper fork, so multiprocessing is wonky - + # this only comes up when IPython has been installed using vanilla # setuptools, and *not* distribute. 
import multiprocessing + p = multiprocessing.current_process() # the main process has name 'MainProcess' # subprocesses will have names like 'Process-1' diff --git a/ipyparallel/client/asyncresult.py b/ipyparallel/client/asyncresult.py index a59a866e3..369747a4e 100644 --- a/ipyparallel/client/asyncresult.py +++ b/ipyparallel/client/asyncresult.py @@ -23,7 +23,7 @@ from IPython import get_ipython from IPython.core.display import clear_output, display, display_pretty from ipyparallel import error -from ipyparallel.util import utcnow, compare_datetimes +from ipyparallel.util import utcnow, compare_datetimes, _parse_date from ipython_genutils.py3compat import string_types from .futures import MessageFuture, multi_future @@ -278,9 +278,28 @@ def r(self): """result property wrapper for `get(timeout=-1)`.""" return self.get() + _DATE_FIELDS = [ + "submitted", + "started", + "completed", + "received", + ] + + def _parse_metadata_dates(self): + """Ensure metadata date fields are parsed on access + + Rather than parsing timestamps from str->dt on receipt, + parse on access for compatibility. + """ + for md in self._metadata: + for key in self._DATE_FIELDS: + if isinstance(md.get(key, None), str): + md[key] = _parse_date(md[key]) + @property def metadata(self): """property for accessing execution metadata.""" + self._parse_metadata_dates() if self._single_result: return self._metadata[0] else: @@ -356,6 +375,7 @@ def __getitem__(self, key): # metadata proxy *does not* require that results are done self.wait(0) self.wait_for_output(0) + self._parse_metadata_dates() values = [ md[key] for md in self._metadata ] if self._single_result: return values[0] @@ -441,30 +461,30 @@ def timedelta(self, start, end, start_key=min, end_key=max): # not a list end = end_key(end) return compare_datetimes(end, start).total_seconds() - + @property def progress(self): """the number of tasks which have been completed at this point. 
- + Fractional progress would be given by 1.0 * ar.progress / len(ar) """ self.wait(0) return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding)) - + @property def elapsed(self): """elapsed time since initial submission""" if self.ready(): return self.wall_time - + now = submitted = utcnow() - for msg_id in self.msg_ids: - if msg_id in self._client.metadata: - stamp = self._client.metadata[msg_id]['submitted'] - if stamp and stamp < submitted: - submitted = stamp + self._parse_metadata_dates() + for md in self._metadata: + stamp = md["submitted"] + if stamp and stamp < submitted: + submitted = stamp return compare_datetimes(now, submitted).total_seconds() - + @property @check_ready def serial_time(self): @@ -473,6 +493,7 @@ def serial_time(self): Computed as the sum of (completed-started) of each task """ t = 0 + self._parse_metadata_dates() for md in self._metadata: t += compare_datetimes(md['completed'], md['started']).total_seconds() return t diff --git a/ipyparallel/client/client.py b/ipyparallel/client/client.py index 1c28ce8e6..aff6c1a61 100644 --- a/ipyparallel/client/client.py +++ b/ipyparallel/client/client.py @@ -5,9 +5,11 @@ from __future__ import print_function +import threading + try: from collections.abc import Iterable -except ImportError: # py2 +except ImportError: # py2 from collections import Iterable import socket from concurrent.futures import Future @@ -37,10 +39,7 @@ from IPython.paths import get_ipython_dir from IPython.utils.path import compress_user from ipython_genutils.py3compat import cast_bytes, string_types, xrange, iteritems -from traitlets import ( - HasTraits, Instance, Unicode, - Dict, List, Bool, Set, Any -) +from traitlets import HasTraits, Instance, Unicode, Dict, List, Bool, Set, Any from decorator import decorator from ipyparallel import Reference @@ -54,11 +53,17 @@ from ..util import ioloop from .asyncresult import AsyncResult, AsyncHubResult from .futures import MessageFuture, multi_future -from .view import DirectView, LoadBalancedView +from .view import ( + DirectView, + LoadBalancedView, + BroadcastView, +) +import jupyter_client.session -#-------------------------------------------------------------------------- +jupyter_client.session.extract_dates = lambda obj: obj +# -------------------------------------------------------------------------- # Decorators for Client methods -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- @decorator @@ -74,9 +79,10 @@ def unpack_message(f, self, msg_parts): pprint(msg) return f(self, msg) -#-------------------------------------------------------------------------- + +# -------------------------------------------------------------------------- # Classes -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- _no_connection_file_msg = """ @@ -84,8 +90,10 @@ def unpack_message(f, self, msg_parts): Please double-check your profile and ensure that a cluster is running. 
""" + class ExecuteReply(RichOutput): """wrapper for finished Execute results""" + def __init__(self, msg_id, content, metadata): self.msg_id = msg_id self._content = content @@ -116,6 +124,7 @@ def _metadata(self): def display(self): from IPython.display import publish_display_data + publish_display_data(self.data, self.metadata) def _repr_mime_(self, mime): @@ -143,7 +152,7 @@ def __getattr__(self, key): return self.metadata[key] def __repr__(self): - execute_result = self.metadata['execute_result'] or {'data':{}} + execute_result = self.metadata['execute_result'] or {'data': {}} text_out = execute_result['data'].get('text/plain', '') if len(text_out) > 32: text_out = text_out[:29] + '...' @@ -151,7 +160,7 @@ def __repr__(self): return "" % (self.execution_count, text_out) def _plaintext(self): - execute_result = self.metadata['execute_result'] or {'data':{}} + execute_result = self.metadata['execute_result'] or {'data': {}} text_out = execute_result['data'].get('text/plain', '') if not text_out: @@ -173,14 +182,14 @@ def _plaintext(self): # add newline for multiline reprs text_out = '\n' + text_out - return u''.join([ - out, - u'Out[%i:%i]: ' % ( - self.metadata['engine_id'], self.execution_count - ), - normal, - text_out, - ]) + return u''.join( + [ + out, + u'Out[%i:%i]: ' % (self.metadata['engine_id'], self.execution_count), + normal, + text_out, + ] + ) def _repr_pretty_(self, p, cycle): p.text(self._plaintext()) @@ -194,27 +203,28 @@ class Metadata(dict): These objects have a strict set of keys - errors will raise if you try to add new keys. """ + def __init__(self, *args, **kwargs): dict.__init__(self) - md = {'msg_id' : None, - 'submitted' : None, - 'started' : None, - 'completed' : None, - 'received' : None, - 'engine_uuid' : None, - 'engine_id' : None, - 'follow' : None, - 'after' : None, - 'status' : None, - - 'execute_input' : None, - 'execute_result' : None, - 'error' : None, - 'stdout' : '', - 'stderr' : '', - 'outputs' : [], - 'data': {}, - } + md = { + 'msg_id': None, + 'submitted': None, + 'started': None, + 'completed': None, + 'received': None, + 'engine_uuid': None, + 'engine_id': None, + 'follow': None, + 'after': None, + 'status': None, + 'execute_input': None, + 'execute_result': None, + 'error': None, + 'stdout': '', + 'stderr': '', + 'outputs': [], + 'data': {}, + } self.update(md) self.update(dict(*args, **kwargs)) @@ -320,7 +330,6 @@ class Client(HasTraits): """ - block = Bool(False) outstanding = Set() results = Instance('collections.defaultdict', (dict,)) @@ -332,7 +341,8 @@ class Client(HasTraits): _io_loop = Any() _io_thread = Any() - profile=Unicode() + profile = Unicode() + def _profile_default(self): if BaseIPythonApplication.initialized(): # an IPython app *might* be running, try to get its profile @@ -345,32 +355,44 @@ def _profile_default(self): else: return u'default' - _outstanding_dict = Instance('collections.defaultdict', (set,)) _ids = List() - _connected=Bool(False) - _ssh=Bool(False) + _connected = Bool(False) + _ssh = Bool(False) _context = Instance('zmq.Context', allow_none=True) _config = Dict() - _engines=Instance(util.ReverseDict, (), {}) - _query_socket=Instance('zmq.Socket', allow_none=True) - _control_socket=Instance('zmq.Socket', allow_none=True) - _iopub_socket=Instance('zmq.Socket', allow_none=True) - _notification_socket=Instance('zmq.Socket', allow_none=True) - _mux_socket=Instance('zmq.Socket', allow_none=True) - _task_socket=Instance('zmq.Socket', allow_none=True) - _task_scheme=Unicode() + _engines = Instance(util.ReverseDict, 
(), {}) + _query_socket = Instance('zmq.Socket', allow_none=True) + _control_socket = Instance('zmq.Socket', allow_none=True) + _iopub_socket = Instance('zmq.Socket', allow_none=True) + _notification_socket = Instance('zmq.Socket', allow_none=True) + _mux_socket = Instance('zmq.Socket', allow_none=True) + _task_socket = Instance('zmq.Socket', allow_none=True) + _broadcast_socket = Instance('zmq.Socket', allow_none=True) + + _task_scheme = Unicode() _closed = False def __new__(self, *args, **kw): # don't raise on positional args return HasTraits.__new__(self, **kw) - def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None, - context=None, debug=False, - sshserver=None, sshkey=None, password=None, paramiko=None, - timeout=10, cluster_id=None, **extra_args - ): + def __init__( + self, + url_file=None, + profile=None, + profile_dir=None, + ipython_dir=None, + context=None, + debug=False, + sshserver=None, + sshkey=None, + password=None, + paramiko=None, + timeout=10, + cluster_id=None, + **extra_args + ): if profile: super(Client, self).__init__(debug=debug, profile=profile) else: @@ -381,17 +403,21 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No if 'url_or_file' in extra_args: url_file = extra_args['url_or_file'] - warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning) + warnings.warn( + "url_or_file arg no longer supported, use url_file", DeprecationWarning + ) if url_file and util.is_url(url_file): raise ValueError("single urls cannot be specified, url-files must be used.") self._setup_profile_dir(self.profile, profile_dir, ipython_dir) - no_file_msg = '\n'.join([ - "You have attempted to connect to an IPython Cluster but no Controller could be found.", - "Please double-check your configuration and ensure that a cluster is running.", - ]) + no_file_msg = '\n'.join( + [ + "You have attempted to connect to an IPython Cluster but no Controller could be found.", + "Please double-check your configuration and ensure that a cluster is running.", + ] + ) if self._cd is not None: if url_file is None: @@ -403,25 +429,25 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No short = compress_user(url_file) if not os.path.exists(url_file): print("Waiting for connection file: %s" % short) - waiting_time = 0. + waiting_time = 0.0 while waiting_time < timeout: - time.sleep(min(timeout-waiting_time, 1)) + time.sleep(min(timeout - waiting_time, 1)) waiting_time += 1 if os.path.exists(url_file): break if not os.path.exists(url_file): - msg = '\n'.join([ - "Connection file %r not found." % short, - no_file_msg, - ]) + msg = '\n'.join( + ["Connection file %r not found." % short, no_file_msg] + ) raise IOError(msg) if url_file is None: raise IOError(no_file_msg) if not os.path.exists(url_file): # Connection file explicitly specified, but not found - raise IOError("Connection file %r not found. Is a controller running?" % \ - compress_user(url_file) + raise IOError( + "Connection file %r not found. Is a controller running?" 
+ % compress_user(url_file) ) with open(url_file) as f: @@ -435,14 +461,22 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No location = cfg.setdefault('location', None) - proto,addr = cfg['interface'].split('://') + proto, addr = cfg['interface'].split('://') addr = util.disambiguate_ip_address(addr, location) cfg['interface'] = "%s://%s" % (proto, addr) # turn interface,port into full urls: - for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'): + for key in ( + 'control', + 'task', + 'mux', + 'iopub', + 'notification', + 'registration', + ): cfg[key] = cfg['interface'] + ':%i' % cfg[key] + cfg['broadcast'] = cfg['interface'] + ':%i' % cfg['broadcast'][0] url = cfg['registration'] if location is not None and addr == localhost(): @@ -452,18 +486,24 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No if not is_local_ip(location_ip) and not sshserver: # load ssh from JSON *only* if the controller is not on # this machine - sshserver=cfg['ssh'] - if not is_local_ip(location_ip) and not sshserver and\ - location != socket.gethostname(): + sshserver = cfg['ssh'] + if ( + not is_local_ip(location_ip) + and not sshserver + and location != socket.gethostname() + ): # warn if no ssh specified, but SSH is probably needed # This is only a warning, because the most likely cause # is a local Controller on a laptop whose IP is dynamic - warnings.warn(""" + warnings.warn( + """ Controller appears to be listening on localhost, but not on this machine. If this is true, you should specify Client(...,sshserver='you@%s') - or instruct your controller to listen on an external IP.""" % location, - RuntimeWarning) + or instruct your controller to listen on an external IP.""" + % location, + RuntimeWarning, + ) elif not sshserver: # otherwise sync with cfg sshserver = cfg['ssh'] @@ -476,10 +516,11 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No sshserver = addr if self._ssh and password is None: from zmq.ssh import tunnel + if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko): - password=False + password = False else: - password = getpass("SSH Password for %s: "%sshserver) + password = getpass("SSH Password for %s: " % sshserver) ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko) # configure and construct the session @@ -489,10 +530,12 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No extra_args['key'] = cast_bytes(cfg['key']) extra_args['signature_scheme'] = cfg['signature_scheme'] except KeyError as exc: - msg = '\n'.join([ - "Connection file is invalid (missing '{}'), possibly from an old version of IPython.", - "If you are reusing connection files, remove them and start ipcontroller again." 
- ]) + msg = '\n'.join( + [ + "Connection file is invalid (missing '{}'), possibly from an old version of IPython.", + "If you are reusing connection files, remove them and start ipcontroller again.", + ] + ) raise ValueError(msg.format(exc.message)) self.session = Session(**extra_args) @@ -501,19 +544,28 @@ def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=No if self._ssh: from zmq.ssh import tunnel - tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, - timeout=timeout, **ssh_kwargs) + + tunnel.tunnel_connection( + self._query_socket, + cfg['registration'], + sshserver, + timeout=timeout, + **ssh_kwargs + ) else: self._query_socket.connect(cfg['registration']) self.session.debug = self.debug - self._notification_handlers = {'registration_notification' : self._register_engine, - 'unregistration_notification' : self._unregister_engine, - 'shutdown_notification' : lambda msg: self.close(), - } - self._queue_handlers = {'execute_reply' : self._handle_execute_reply, - 'apply_reply' : self._handle_apply_reply} + self._notification_handlers = { + 'registration_notification': self._register_engine, + 'unregistration_notification': self._unregister_engine, + 'shutdown_notification': lambda msg: self.close(), + } + self._queue_handlers = { + 'execute_reply': self._handle_execute_reply, + 'apply_reply': self._handle_apply_reply, + } try: self._connect(sshserver, ssh_kwargs, timeout) @@ -547,8 +599,7 @@ def _setup_profile_dir(self, profile, profile_dir, ipython_dir): pass elif profile is not None: try: - self._cd = ProfileDir.find_profile_dir_by_name( - ipython_dir, profile) + self._cd = ProfileDir.find_profile_dir_by_name(ipython_dir, profile) return except ProfileDirError: pass @@ -556,14 +607,17 @@ def _setup_profile_dir(self, profile, profile_dir, ipython_dir): def _update_engines(self, engines): """Update our engines dict and _ids from a dict of the form: {id:uuid}.""" - for k,v in iteritems(engines): + for k, v in iteritems(engines): eid = int(k) if eid not in self._engines: self._ids.append(eid) self._engines[eid] = v self._ids = sorted(self._ids) - if sorted(self._engines.keys()) != list(range(len(self._engines))) and \ - self._task_scheme == 'pure' and self._task_socket: + if ( + sorted(self._engines.keys()) != list(range(len(self._engines))) + and self._task_scheme == 'pure' + and self._task_socket + ): self._stop_scheduling_tasks() def _stop_scheduling_tasks(self): @@ -572,11 +626,15 @@ def _stop_scheduling_tasks(self): """ self._task_socket.close() self._task_socket = None - msg = "An engine has been unregistered, and we are using pure " +\ - "ZMQ task scheduling. Task farming will be disabled." + msg = ( + "An engine has been unregistered, and we are using pure " + + "ZMQ task scheduling. Task farming will be disabled." + ) if self.outstanding: - msg += " If you were running tasks when this happened, " +\ - "some `outstanding` msg_ids may never resolve." + msg += ( + " If you were running tasks when this happened, " + + "some `outstanding` msg_ids may never resolve." 
+ ) warnings.warn(msg, RuntimeWarning) def _build_targets(self, targets): @@ -586,7 +644,9 @@ def _build_targets(self, targets): if not self._ids: # flush notification socket if no engines yet, just in case if not self.ids: - raise error.NoEnginesRegistered("Can't build targets without any engines") + raise error.NoEnginesRegistered( + "Can't build targets without any engines" + ) if targets is None: targets = self._ids @@ -594,21 +654,23 @@ def _build_targets(self, targets): if targets.lower() == 'all': targets = self._ids else: - raise TypeError("%r not valid str target, must be 'all'"%(targets)) + raise TypeError("%r not valid str target, must be 'all'" % (targets)) elif isinstance(targets, int): if targets < 0: targets = self.ids[targets] if targets not in self._ids: - raise IndexError("No such engine: %i"%targets) + raise IndexError("No such engine: %i" % targets) targets = [targets] if isinstance(targets, slice): indices = list(range(len(self._ids))[targets]) ids = self.ids - targets = [ ids[i] for i in indices ] + targets = [ids[i] for i in indices] if not isinstance(targets, (tuple, list, xrange)): - raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) + raise TypeError( + "targets by int/slice/collection of ints only, not %s" % (type(targets)) + ) return [cast_bytes(self._engines[t]) for t in targets], list(targets) @@ -619,11 +681,12 @@ def _connect(self, sshserver, ssh_kwargs, timeout): # Maybe allow reconnecting? if self._connected: return - self._connected=True + self._connected = True def connect_socket(s, url): if self._ssh: from zmq.ssh import tunnel + return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs) else: return s.connect(url) @@ -633,7 +696,7 @@ def connect_socket(s, url): poller = zmq.Poller() poller.register(self._query_socket, zmq.POLLIN) # poll expects milliseconds, timeout is seconds - evts = poller.poll(timeout*1000) + evts = poller.poll(timeout * 1000) if not evts: raise error.TimeoutError("Hub connection request timed out") idents, msg = self.session.recv(self._query_socket, mode=0) @@ -649,6 +712,11 @@ def connect_socket(s, url): self._task_socket = self._context.socket(zmq.DEALER) connect_socket(self._task_socket, cfg['task']) + self._broadcast_socket = self._context.socket(zmq.DEALER) + connect_socket( + self._broadcast_socket, cfg['broadcast'] + ) + self._notification_socket = self._context.socket(zmq.SUB) self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'') connect_socket(self._notification_socket, cfg['notification']) @@ -668,9 +736,9 @@ def connect_socket(s, url): self._start_io_thread() - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # handlers and callbacks for incoming messages - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def _unwrap_exception(self, content): """unwrap exception, and remap engine_id to int.""" @@ -687,17 +755,21 @@ def _extract_metadata(self, msg): parent = msg['parent_header'] msg_meta = msg['metadata'] content = msg['content'] - md = {'msg_id' : parent['msg_id'], - 'received' : util.utcnow(), - 'engine_uuid' : msg_meta.get('engine', None), - 'follow' : msg_meta.get('follow', []), - 'after' : msg_meta.get('after', []), - 'status' : content['status'], - } + md = { + 'msg_id': parent['msg_id'], + 'received': util.utcnow(), + 'engine_uuid': msg_meta.get('engine', 
None), + 'follow': msg_meta.get('follow', []), + 'after': msg_meta.get('after', []), + 'status': content['status'], + 'is_broadcast': msg_meta.get( + 'is_broadcast', False + ), + 'is_coalescing': msg_meta.get('is_coalescing', False), + } if md['engine_uuid'] is not None: md['engine_id'] = self._engines.get(md['engine_uuid'], None) - if 'date' in parent: md['submitted'] = parent['date'] if 'started' in msg_meta: @@ -710,7 +782,7 @@ def _register_engine(self, msg): """Register a new engine, and update our connection info.""" content = msg['content'] eid = content['id'] - d = {eid : content['uuid']} + d = {eid: content['uuid']} self._update_engines(d) def _unregister_engine(self, msg): @@ -741,7 +813,9 @@ def _handle_stranded_msgs(self, eid, uuid): # we already continue try: - raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id)) + raise error.EngineError( + "Engine %r died while running task %r" % (eid, msg_id) + ) except: content = error.wrap_exception() # build a fake message: @@ -761,9 +835,9 @@ def _handle_execute_reply(self, msg): future = self._futures.get(msg_id, None) if msg_id not in self.outstanding: if msg_id in self.history: - print("got stale result: %s"%msg_id) + print("got stale result: %s" % msg_id) else: - print("got unknown result: %s"%msg_id) + print("got unknown result: %s" % msg_id) else: self.outstanding.remove(msg_id) @@ -800,18 +874,26 @@ def _handle_execute_reply(self, msg): if future: future.set_result(self.results[msg_id]) + def _should_use_metadata_msg_id(self, msg): + md = msg['metadata'] + return md.get('is_broadcast', False) and md.get('is_coalescing', False) + def _handle_apply_reply(self, msg): """Save the reply to an apply_request into our results.""" parent = msg['parent_header'] - msg_id = parent['msg_id'] + if self._should_use_metadata_msg_id(msg): + msg_id = msg['metadata']['original_msg_id'] + else: + msg_id = parent['msg_id'] + future = self._futures.get(msg_id, None) if msg_id not in self.outstanding: if msg_id in self.history: - print("got stale result: %s"%msg_id) + print("got stale result: %s" % msg_id) print(self.results[msg_id]) print(msg) else: - print("got unknown result: %s"%msg_id) + print("got unknown result: %s" % msg_id) else: self.outstanding.remove(msg_id) content = msg['content'] @@ -827,7 +909,15 @@ def _handle_apply_reply(self, msg): # construct result: if content['status'] == 'ok': - self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0] + if md.get('is_coalescing', False): + deserialized_bufs = [] + bufs = msg['buffers'] + while bufs: + deserialized, bufs = serialize.deserialize_object(bufs) + deserialized_bufs.append(deserialized) + self.results[msg_id] = deserialized_bufs + else: + self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0] elif content['status'] == 'aborted': self.results[msg_id] = error.TaskAborted(msg_id) out_future = self._output_futures.get(msg_id) @@ -851,6 +941,7 @@ def _make_io_loop(self): if 'asyncio' in sys.modules: # tornado 5 on asyncio requires creating a new asyncio loop import asyncio + try: asyncio.get_event_loop() except RuntimeError: @@ -882,6 +973,11 @@ def _setup_streams(self): self._notification_stream = ZMQStream(self._notification_socket, self._io_loop) self._notification_stream.on_recv(self._dispatch_notification, copy=False) + self._broadcast_stream = ZMQStream( + self._broadcast_socket, self._io_loop + ) + self._broadcast_stream.on_recv(self._dispatch_reply, copy=False) + def _start_io_thread(self): """Start IOLoop in a background 
thread.""" evt = Event() @@ -895,7 +991,9 @@ def _start_io_thread(self): if not self._io_thread.is_alive(): raise RuntimeError("IO Loop failed to start") else: - raise RuntimeError("Start event was never set. Maybe a problem in the IO thread.") + raise RuntimeError( + "Start event was never set. Maybe a problem in the IO thread." + ) def _io_main(self, start_evt=None): """main loop for background IO thread""" @@ -960,9 +1058,9 @@ def _dispatch_iopub(self, msg): s = md[name] or '' md[name] = s + content['text'] elif msg_type == 'error': - md.update({'error' : self._unwrap_exception(content)}) + md.update({'error': self._unwrap_exception(content)}) elif msg_type == 'execute_input': - md.update({'execute_input' : content['code']}) + md.update({'execute_input': content['code']}) elif msg_type == 'display_data': md['outputs'].append(content) elif msg_type == 'execute_result': @@ -981,32 +1079,48 @@ def _dispatch_iopub(self, msg): # unhandled msg_type (status, etc.) pass - def _send(self, socket, msg_type, content=None, parent=None, ident=None, - buffers=None, track=False, header=None, metadata=None): + def create_message_futures(self, msg_id, async_result=False, track=False): + msg_future = MessageFuture(msg_id, track=track) + futures = [msg_future] + self._futures[msg_id] = msg_future + if async_result: + output = MessageFuture(msg_id) + # add future for output + self._output_futures[msg_id] = output + # hook up metadata + output.metadata = self.metadata[msg_id] + output.metadata['submitted'] = util.utcnow() + msg_future.output = output + futures.append(output) + return futures + + def _send( + self, + socket, + msg_type, + content=None, + parent=None, + ident=None, + buffers=None, + track=False, + header=None, + metadata=None, + ): """Send a message in the IO thread returns msg object""" if self._closed: raise IOError("Connections have been closed.") - msg = self.session.msg(msg_type, content=content, parent=parent, - header=header, metadata=metadata) + msg = self.session.msg( + msg_type, content=content, parent=parent, header=header, metadata=metadata + ) msg_id = msg['header']['msg_id'] - asyncresult = False - if msg_type in {'execute_request', 'apply_request'}: - asyncresult = True - # add future for output - self._output_futures[msg_id] = output = MessageFuture(msg_id) - # hook up metadata - output.metadata = self.metadata[msg_id] - - self._futures[msg_id] = future = MessageFuture(msg_id, track=track) - futures = [future] - - if asyncresult: - future.output = output - futures.append(output) - output.metadata['submitted'] = util.utcnow() + futures = self.create_message_futures( + msg_id, + async_result=msg_type in {'execute_request', 'apply_request'}, + track=track, + ) def cleanup(f): """Purge caches on Future resolution""" @@ -1018,13 +1132,15 @@ def cleanup(f): multi_future(futures).add_done_callback(cleanup) def _really_send(): - sent = self.session.send(socket, msg, track=track, buffers=buffers, ident=ident) + sent = self.session.send( + socket, msg, track=track, buffers=buffers, ident=ident + ) if track: - future.tracker.set_result(sent['tracker']) + futures[0].tracker.set_result(sent['tracker']) # hand off actual send to IO thread self._io_loop.add_callback(_really_send) - return future + return futures[0] def _send_recv(self, *args, **kwargs): """Send a message in the IO thread and return its reply""" @@ -1032,9 +1148,9 @@ def _send_recv(self, *args, **kwargs): future.wait() return future.result() - #-------------------------------------------------------------------------- + # 
-------------------------------------------------------------------------- # len, getitem - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def __len__(self): """len(client) returns # of engines.""" @@ -1045,7 +1161,9 @@ def __getitem__(self, key): Must be int, slice, or list/tuple/xrange of ints""" if not isinstance(key, (int, slice, tuple, list, xrange)): - raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key))) + raise TypeError( + "key by int/slice/iterable of ints only, not %s" % (type(key)) + ) else: return self.direct_view(key) @@ -1058,9 +1176,9 @@ def __iter__(self): for eid in self.ids: yield self.direct_view(eid) - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Begin public methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- @property def ids(self): @@ -1100,7 +1218,7 @@ def close(self, linger=None): if self._closed: return self._stop_io_thread() - snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ] + snames = [trait for trait in self.trait_names() if trait.endswith("socket")] for name in snames: socket = getattr(self, name) if socket is not None and not socket.closed: @@ -1112,16 +1230,23 @@ def close(self, linger=None): def spin_thread(self, interval=1): """DEPRECATED, DOES NOTHING""" - warnings.warn("Client.spin_thread is deprecated now that IO is always in a thread", DeprecationWarning) + warnings.warn( + "Client.spin_thread is deprecated now that IO is always in a thread", + DeprecationWarning, + ) def stop_spin_thread(self): """DEPRECATED, DOES NOTHING""" - warnings.warn("Client.spin_thread is deprecated now that IO is always in a thread", DeprecationWarning) + warnings.warn( + "Client.spin_thread is deprecated now that IO is always in a thread", + DeprecationWarning, + ) def spin(self): """DEPRECATED, DOES NOTHING""" - warnings.warn("Client.spin is deprecated now that IO is in a thread", DeprecationWarning) - + warnings.warn( + "Client.spin is deprecated now that IO is in a thread", DeprecationWarning + ) def _await_futures(self, futures, timeout): """Wait for a collection of futures""" @@ -1175,8 +1300,9 @@ def wait(self, jobs=None, timeout=-1): # make a copy, so that we aren't passing a mutable collection to _futures_for_msgs theids = set(self.outstanding) else: - if isinstance(jobs, string_types + (int, AsyncResult)) \ - or not isinstance(jobs, Iterable): + if isinstance(jobs, string_types + (int, AsyncResult)) or not isinstance( + jobs, Iterable + ): jobs = [jobs] theids = set() for job in jobs: @@ -1196,22 +1322,22 @@ def wait(self, jobs=None, timeout=-1): futures.extend(self._futures_for_msgs(theids)) return self._await_futures(futures, timeout) - def wait_interactive(self, jobs=None, interval=1., timeout=-1.): + def wait_interactive(self, jobs=None, interval=1.0, timeout=-1.0): """Wait interactively for jobs If no job is specified, will wait for all outstanding jobs to complete. 
""" if jobs is None: # get futures for results - futures = [ f for f in self._futures.values() if hasattr(f, 'output') ] + futures = [f for f in self._futures.values() if hasattr(f, 'output')] ar = AsyncResult(self, futures, owner=False) else: ar = self._asyncresult_from_jobs(jobs, owner=False) return ar.wait_interactive(interval=interval, timeout=timeout) - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Control methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def clear(self, targets=None, block=None): """Clear the namespace in target(s).""" @@ -1219,7 +1345,9 @@ def clear(self, targets=None, block=None): targets = self._build_targets(targets)[0] futures = [] for t in targets: - futures.append(self._send(self._control_stream, 'clear_request', content={}, ident=t)) + futures.append( + self._send(self._control_stream, 'clear_request', content={}, ident=t) + ) if not block: return multi_future(futures) for future in futures: @@ -1228,7 +1356,6 @@ def clear(self, targets=None, block=None): if msg['content']['status'] != 'ok': raise self._unwrap_exception(msg['content']) - def abort(self, jobs=None, targets=None, block=None): """Abort specific jobs from the execution queues of target(s). @@ -1251,9 +1378,13 @@ def abort(self, jobs=None, targets=None, block=None): msg_ids = [] if isinstance(jobs, string_types + (AsyncResult,)): jobs = [jobs] - bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))] + bad_ids = [ + obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,)) + ] if bad_ids: - raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0]) + raise TypeError( + "Invalid msg_id type %r, expected str or AsyncResult" % bad_ids[0] + ) for j in jobs: if isinstance(j, AsyncResult): msg_ids.extend(j.msg_ids) @@ -1262,8 +1393,11 @@ def abort(self, jobs=None, targets=None, block=None): content = dict(msg_ids=msg_ids) futures = [] for t in targets: - futures.append(self._send(self._control_stream, 'abort_request', - content=content, ident=t)) + futures.append( + self._send( + self._control_stream, 'abort_request', content=content, ident=t + ) + ) if not block: return multi_future(futures) @@ -1291,6 +1425,7 @@ def shutdown(self, targets='all', restart=False, hub=False, block=None): whether to restart engines after shutting them down. 
""" from ipyparallel.error import NoEnginesRegistered + if restart: raise NotImplementedError("Engine restart is not yet implemented") @@ -1304,8 +1439,14 @@ def shutdown(self, targets='all', restart=False, hub=False, block=None): futures = [] for t in targets: - futures.append(self._send(self._control_stream, 'shutdown_request', - content={'restart':restart},ident=t)) + futures.append( + self._send( + self._control_stream, + 'shutdown_request', + content={'restart': restart}, + ident=t, + ) + ) error = False if block or hub: for f in futures: @@ -1326,7 +1467,9 @@ def shutdown(self, targets='all', restart=False, hub=False, block=None): if error: raise error - def become_dask(self, targets='all', port=0, nanny=False, scheduler_args=None, **worker_args): + def become_dask( + self, targets='all', port=0, nanny=False, scheduler_args=None, **worker_args + ): """Turn the IPython cluster into a dask.distributed cluster Parameters @@ -1357,10 +1500,12 @@ def become_dask(self, targets='all', port=0, nanny=False, scheduler_args=None, * if scheduler_args is None: scheduler_args = {} else: - scheduler_args = dict(scheduler_args) # copy + scheduler_args = dict(scheduler_args) # copy # Start a Scheduler on the Hub: - reply = self._send_recv(self._query_stream, 'become_dask_request', + reply = self._send_recv( + self._query_stream, + 'become_dask_request', {'scheduler_args': scheduler_args}, ) if reply['content']['status'] != 'ok': @@ -1385,7 +1530,6 @@ def become_dask(self, targets='all', port=0, nanny=False, scheduler_args=None, * return client - def stop_dask(self, targets='all'): """Stop the distributed Scheduler and Workers started by become_dask. @@ -1409,9 +1553,9 @@ def stop_dask(self, targets='all'): become_distributed = become_dask stop_distributed = stop_dask - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Execution related methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def _maybe_raise(self, result): """wrapper for maybe raising an exception if apply failed.""" @@ -1420,15 +1564,18 @@ def _maybe_raise(self, result): return result - def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False, - ident=None): + def send_apply_request( + self, socket, f, args=None, kwargs=None, metadata=None, track=False, ident=None + ): """construct and send an apply message via a socket. This is the principal method with which all engine execution is performed by views. 
""" if self._closed: - raise RuntimeError("Client cannot be used after its sockets have been closed") + raise RuntimeError( + "Client cannot be used after its sockets have been closed" + ) # defaults: args = args if args is not None else [] @@ -1437,21 +1584,30 @@ def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, t # validate arguments if not callable(f) and not isinstance(f, (Reference, PrePickled)): - raise TypeError("f must be callable, not %s"%type(f)) + raise TypeError("f must be callable, not %s" % type(f)) if not isinstance(args, (tuple, list)): - raise TypeError("args must be tuple or list, not %s"%type(args)) + raise TypeError("args must be tuple or list, not %s" % type(args)) if not isinstance(kwargs, dict): - raise TypeError("kwargs must be dict, not %s"%type(kwargs)) + raise TypeError("kwargs must be dict, not %s" % type(kwargs)) if not isinstance(metadata, dict): - raise TypeError("metadata must be dict, not %s"%type(metadata)) + raise TypeError("metadata must be dict, not %s" % type(metadata)) - bufs = serialize.pack_apply_message(f, args, kwargs, + bufs = serialize.pack_apply_message( + f, + args, + kwargs, buffer_threshold=self.session.buffer_threshold, item_threshold=self.session.item_threshold, ) - future = self._send(socket, "apply_request", buffers=bufs, ident=ident, - metadata=metadata, track=track) + future = self._send( + socket, + "apply_request", + buffers=bufs, + ident=ident, + metadata=metadata, + track=track, + ) msg_id = future.msg_id self.outstanding.add(msg_id) @@ -1466,13 +1622,17 @@ def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, t return future - def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None): + def send_execute_request( + self, socket, code, silent=True, metadata=None, ident=None + ): """construct and send an execute request via a socket. """ if self._closed: - raise RuntimeError("Client cannot be used after its sockets have been closed") + raise RuntimeError( + "Client cannot be used after its sockets have been closed" + ) # defaults: metadata = metadata if metadata is not None else {} @@ -1485,9 +1645,9 @@ def send_execute_request(self, socket, code, silent=True, metadata=None, ident=N content = dict(code=code, silent=bool(silent), user_expressions={}) - - future = self._send(socket, "execute_request", content=content, ident=ident, - metadata=metadata) + future = self._send( + socket, "execute_request", content=content, ident=ident, metadata=metadata + ) msg_id = future.msg_id self.outstanding.add(msg_id) @@ -1503,9 +1663,9 @@ def send_execute_request(self, socket, code, silent=True, metadata=None, ident=N return future - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # construct a View object - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def load_balanced_view(self, targets=None, **kwargs): """construct a DirectView object. 
@@ -1524,8 +1684,9 @@ def load_balanced_view(self, targets=None, **kwargs):
             targets = None
         if targets is not None:
             targets = self._build_targets(targets)[1]
-        return LoadBalancedView(client=self, socket=self._task_stream, targets=targets,
-                            **kwargs)
+        return LoadBalancedView(
+            client=self, socket=self._task_stream, targets=targets, **kwargs
+        )
 
     def executor(self, targets=None):
         """Construct a PEP-3148 Executor with a LoadBalancedView
@@ -1569,12 +1730,36 @@ def direct_view(self, targets='all', **kwargs):
         targets = self._build_targets(targets)[1]
         if single:
             targets = targets[0]
-        return DirectView(client=self, socket=self._mux_stream, targets=targets,
-                            **kwargs)
+        return DirectView(
+            client=self, socket=self._mux_stream, targets=targets, **kwargs
+        )
+
+    def broadcast_view(self, targets='all', is_coalescing=False, **kwargs):
+        """construct a BroadcastView object.
+
+        If no arguments are specified, create a BroadcastView using
+        all engines.
+
+        Parameters
+        ----------
+
+        targets: list,slice,int,etc. [default: use all engines]
+            The subset of engines the broadcast will be sent to
+        is_coalescing: bool [default: False]
+            If True, the scheduler collects the replies from all engines
+            and returns them to the client as a single message
+        kwargs: passed to BroadcastView
+        """
+        targets = self._build_targets(targets)[1]
+
+        bcast_view = BroadcastView(
+            client=self,
+            socket=self._broadcast_stream,
+            targets=targets,
+        )
+        bcast_view.is_coalescing = is_coalescing
+        return bcast_view
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Query methods
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
         """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
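+
+        Example (illustrative; ``msg_id`` stands for any id recorded in
+        ``rc.history``)::
+
+            ar = rc.get_result(-1)                # wrap the most recent task
+            ar = rc.get_result(msg_id, block=False)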
@@ -1656,14 +1841,14 @@ def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None): indices_or_msg_ids = -1 theids = self._msg_ids_from_jobs(indices_or_msg_ids) - content = dict(msg_ids = theids) + content = dict(msg_ids=theids) reply = self._send_recv(self._query_stream, 'resubmit_request', content) content = reply['content'] if content['status'] != 'ok': raise self._unwrap_exception(content) mapping = content['resubmitted'] - new_ids = [ mapping[msg_id] for msg_id in theids ] + new_ids = [mapping[msg_id] for msg_id in theids] ar = AsyncHubResult(self, new_ids) @@ -1708,15 +1893,17 @@ def result_status(self, msg_ids, status_only=True): local_results[msg_id] = self.results[msg_id] theids.remove(msg_id) - if theids: # some not locally cached + if theids: # some not locally cached content = dict(msg_ids=theids, status_only=status_only) - reply = self._send_recv(self._query_stream, "result_request", content=content) + reply = self._send_recv( + self._query_stream, "result_request", content=content + ) content = reply['content'] if content['status'] != 'ok': raise self._unwrap_exception(content) buffers = reply['buffers'] else: - content = dict(completed=[],pending=[]) + content = dict(completed=[], pending=[]) content['completed'].extend(completed) @@ -1752,7 +1939,7 @@ def result_status(self, msg_ids, status_only=True): if rcontent['status'] == 'ok': if header['msg_type'] == 'apply_reply': - res,buffers = serialize.deserialize_object(buffers) + res, buffers = serialize.deserialize_object(buffers) elif header['msg_type'] == 'execute_reply': res = ExecuteReply(msg_id, rcontent, md) else: @@ -1801,10 +1988,14 @@ def queue_status(self, targets='all', verbose=False): def _msg_ids_from_target(self, targets=None): """Build a list of msg_ids from the list of engine targets""" - if not targets: # needed as _build_targets otherwise uses all engines + if not targets: # needed as _build_targets otherwise uses all engines return [] target_ids = self._build_targets(targets)[0] - return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids] + return [ + md_id + for md_id in self.metadata + if self.metadata[md_id]["engine_uuid"] in target_ids + ] def _msg_ids_from_jobs(self, jobs=None): """Given a 'jobs' argument, convert it to a list of msg_ids. @@ -1904,7 +2095,9 @@ def purge_local_results(self, jobs=[], targets=[]): if jobs == 'all': if self.outstanding: - raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding) + raise RuntimeError( + "Can't purge outstanding tasks: %s" % self.outstanding + ) self.results.clear() self.metadata.clear() self._futures.clear() @@ -1915,14 +2108,15 @@ def purge_local_results(self, jobs=[], targets=[]): msg_ids.update(self._msg_ids_from_jobs(jobs)) still_outstanding = self.outstanding.intersection(msg_ids) if still_outstanding: - raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding) + raise RuntimeError( + "Can't purge outstanding tasks: %s" % still_outstanding + ) for mid in msg_ids: self.results.pop(mid, None) self.metadata.pop(mid, None) self._futures.pop(mid, None) self._output_futures.pop(mid, None) - def purge_hub_results(self, jobs=[], targets=[]): """Tell the Hub to forget results. 
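+
+        Example (illustrative; engine ids and msg_ids are placeholders)::
+
+            rc.purge_hub_results('all')           # drop every non-pending record
+            rc.purge_hub_results(targets=[0])     # drop records from engine 0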
@@ -1958,7 +2152,7 @@ def purge_hub_results(self, jobs=[], targets=[]): if content['status'] != 'ok': raise self._unwrap_exception(content) - def purge_results(self, jobs=[], targets=[]): + def purge_results(self, jobs=[], targets=[]): """Clears the cached results from both the hub and the local client Individual results can be purged by msg_id, or the entire @@ -2045,7 +2239,7 @@ def db_query(self, query, keys=None): buffers = reply['buffers'] has_bufs = buffer_lens is not None has_rbufs = result_buffer_lens is not None - for i,rec in enumerate(records): + for i, rec in enumerate(records): # unpack datetime objects for hkey in ('header', 'result_header'): if hkey in rec: @@ -2056,11 +2250,12 @@ def db_query(self, query, keys=None): # relink buffers if has_bufs: blen = buffer_lens[i] - rec['buffers'], buffers = buffers[:blen],buffers[blen:] + rec['buffers'], buffers = buffers[:blen], buffers[blen:] if has_rbufs: blen = result_buffer_lens[i] - rec['result_buffers'], buffers = buffers[:blen],buffers[blen:] + rec['result_buffers'], buffers = buffers[:blen], buffers[blen:] return records -__all__ = [ 'Client' ] + +__all__ = ['Client'] diff --git a/ipyparallel/client/view.py b/ipyparallel/client/view.py index aa99ce270..c5a18757f 100644 --- a/ipyparallel/client/view.py +++ b/ipyparallel/client/view.py @@ -6,6 +6,7 @@ from __future__ import absolute_import, print_function import imp +import threading import warnings from contextlib import contextmanager @@ -128,7 +129,7 @@ def __len__(self): return 1 else: return len(self.client) - + def set_flags(self, **kwargs): """set my attribute flags by keyword. @@ -578,7 +579,6 @@ def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, tra pass return ar - @sync_results def map(self, f, *sequences, **kwargs): """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult @@ -819,6 +819,7 @@ def activate(self, suffix=''): Parameters ---------- + suffix: str [default: ''] The suffix, if any, for the magics. This allows you to have multiple views associated with parallel magics at the same time. 
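+
+        Example (illustrative; assumes an interactive IPython session)::
+
+            dview = rc.direct_view()
+            dview.activate('0')      # registers %px0, %%px0, %autopx0
+            %px0 import numpy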
@@ -838,6 +839,61 @@ def activate(self, suffix=''):
         ip.magics_manager.register(M)
 
 
+class BroadcastView(DirectView):
+    is_coalescing = Bool(False)
+
+    @sync_results
+    @save_ids
+    def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, targets=None):
+        args = [] if args is None else args
+        kwargs = {} if kwargs is None else kwargs
+        block = self.block if block is None else block
+        track = self.track if track is None else track
+        targets = self.targets if targets is None else targets
+        idents, _targets = self.client._build_targets(targets)
+        futures = []
+
+        pf = PrePickled(f)
+        pargs = [PrePickled(arg) for arg in args]
+        pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}
+
+        s_idents = [ident.decode("utf8") for ident in idents]
+
+        metadata = dict(targets=s_idents, is_broadcast=True, is_coalescing=self.is_coalescing)
+        if not self.is_coalescing:
+            original_future = self.client.send_apply_request(
+                self._socket, pf, pargs, pkwargs,
+                track=track, metadata=metadata)
+            original_msg_id = original_future.msg_id
+
+            # The scheduler fans the request out under derived msg_ids, one
+            # per engine, so register one future per derived id and retire
+            # the original request id.
+            for ident in s_idents:
+                msg_and_target_id = f'{original_msg_id}_{ident}'
+                future = self.client.create_message_futures(msg_and_target_id, async_result=True, track=True)
+                self.client.outstanding.add(msg_and_target_id)
+                self.outstanding.add(msg_and_target_id)
+                futures.append(future[0])
+            if original_msg_id in self.outstanding:
+                self.outstanding.remove(original_msg_id)
+        else:
+            message_future = self.client.send_apply_request(
+                self._socket, pf, pargs, pkwargs,
+                track=track, metadata=metadata
+            )
+            self.client.outstanding.add(message_future.msg_id)
+            futures = message_future
+
+        ar = AsyncResult(self.client, futures, fname=getname(f), targets=_targets,
+                         owner=True)
+        if block:
+            try:
+                return ar.get()
+            except KeyboardInterrupt:
+                pass
+        return ar
+
+    def map(self, f, *sequences, **kwargs):
+        # a broadcast sends the same call everywhere; it cannot partition
+        # sequences across engines
+        raise NotImplementedError(
+            "BroadcastView does not support map; use a DirectView or "
+            "LoadBalancedView to map over sequences"
+        )
+
+
 class LoadBalancedView(View):
     """An load-balancing View that only executes via the Task scheduler.
@@ -956,6 +1012,7 @@ def set_flags(self, **kwargs):
             if t is not None:
                 if t < 0:
                     raise ValueError("Invalid timeout: %s"%t)
+                self.timeout = t
 
     @sync_results
@@ -1055,7 +1112,6 @@ def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
     @save_ids
     def map(self, f, *sequences, **kwargs):
         """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
-
         Parallel version of builtin `map`, load-balanced by this View.
 
         `block`, and `chunksize` can be specified by keyword only.
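+
+        Example (illustrative; assumes ``lview = rc.load_balanced_view()``)::
+
+            ar = lview.map(lambda x: x ** 2, range(32), chunksize=4)
+            squares = ar.get()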
@@ -1165,5 +1221,5 @@ def shutdown(self, wait=True): if wait: self.view.wait() -__all__ = ['LoadBalancedView', 'DirectView', 'ViewExecutor'] +__all__ = ['LoadBalancedView', 'DirectView', 'ViewExecutor', 'BroadcastView'] diff --git a/ipyparallel/controller/broadcast_scheduler.py b/ipyparallel/controller/broadcast_scheduler.py new file mode 100644 index 000000000..0c73d9608 --- /dev/null +++ b/ipyparallel/controller/broadcast_scheduler.py @@ -0,0 +1,211 @@ +import logging +import os + +import zmq + +from traitlets import Integer, List, Bytes, Bool + +from ipyparallel import util +from ipyparallel.controller.scheduler import ( + Scheduler, + get_common_scheduler_streams, + ZMQStream, +) + + +class BroadcastScheduler(Scheduler): + port_name = 'broadcast' + accumulated_replies = {} + is_leaf = Bool(False) + connected_sub_scheduler_ids = List(Bytes()) + outgoing_streams = List() + + def start(self): + self.log.info( + 'Broadcast Scheduler started with pid=%s', + os.getpid(), + ) + self.client_stream.on_recv(self.dispatch_submission, copy=False) + if self.is_leaf: + super().start() + else: + for outgoing_stream in self.outgoing_streams: + outgoing_stream.on_recv(self.dispatch_result, copy=False) + + def send_to_targets(self, msg, original_msg_id, targets, idents, is_coalescing): + if is_coalescing: + self.accumulated_replies[original_msg_id] = { + bytes(target, 'utf8'): None for target in targets + } + + for target in targets: + new_msg = self.append_new_msg_id_to_msg( + self.get_new_msg_id(original_msg_id, target), target, idents, msg + ) + self.engine_stream.send_multipart(new_msg, copy=False) + + def send_to_sub_schedulers( + self, msg, original_msg_id, targets, idents, is_coalescing + ): + if is_coalescing: + self.accumulated_replies[original_msg_id] = { + scheduler_id: None for scheduler_id in self.connected_sub_scheduler_ids + } + + for i, scheduler_id in enumerate(self.connected_sub_scheduler_ids): + slice_start = i * len(targets) // len(self.connected_sub_scheduler_ids) + slice_end = (i + 1) * len(targets) // len(self.connected_sub_scheduler_ids) + targets_for_scheduler = targets[slice_start:slice_end] + if not targets_for_scheduler and is_coalescing: + del self.accumulated_replies[original_msg_id][scheduler_id] + msg['metadata']['targets'] = targets_for_scheduler + + new_msg = self.append_new_msg_id_to_msg( + self.get_new_msg_id(original_msg_id, scheduler_id), + scheduler_id, + idents, + msg, + ) + self.outgoing_streams[i].send_multipart(new_msg, copy=False) + + def coalescing_reply(self, raw_msg, msg, original_msg_id, outgoing_id): + if all( + msg is not None or stored_outgoing_id == outgoing_id + for stored_outgoing_id, msg in self.accumulated_replies[ + original_msg_id + ].items() + ): + new_msg = raw_msg[1:] + new_msg.extend( + [ + buffer + for msg_buffers in self.accumulated_replies[ + original_msg_id + ].values() + if msg_buffers + for buffer in msg_buffers + ] + ) + self.client_stream.send_multipart(new_msg, copy=False) + del self.accumulated_replies[original_msg_id] + else: + self.accumulated_replies[original_msg_id][outgoing_id] = msg['buffers'] + + @util.log_errors + def dispatch_submission(self, raw_msg): + try: + idents, msg_list = self.session.feed_identities(raw_msg, copy=False) + msg = self.session.deserialize(msg_list, content=False, copy=False) + except: + self.log.error( + f'broadcast::Invalid broadcast msg: {raw_msg}', exc_info=True + ) + return + metadata = msg['metadata'] + msg_id = msg['header']['msg_id'] + targets = metadata['targets'] + + is_coalescing = 
metadata['is_coalescing'] + + if 'original_msg_id' not in metadata: + metadata['original_msg_id'] = msg_id + + original_msg_id = metadata['original_msg_id'] + if self.is_leaf: + self.send_to_targets(msg, original_msg_id, targets, idents, is_coalescing) + else: + self.send_to_sub_schedulers( + msg, original_msg_id, targets, idents, is_coalescing + ) + + @util.log_errors + def dispatch_result(self, raw_msg): + try: + idents, msg = self.session.feed_identities(raw_msg, copy=False) + msg = self.session.deserialize(msg, content=False, copy=False) + outgoing_id = idents[0] + + except: + self.log.error( + f'broadcast::Invalid broadcast msg: {raw_msg}', exc_info=True + ) + return + + original_msg_id = msg['metadata']['original_msg_id'] + is_coalescing = msg['metadata']['is_coalescing'] + if is_coalescing: + self.coalescing_reply(raw_msg, msg, original_msg_id, outgoing_id) + else: + self.client_stream.send_multipart(raw_msg[1:], copy=False) + + +def get_id_with_prefix(identity): + return bytes(f'sub_scheduler_{identity}', 'utf8') + + +def launch_broadcast_scheduler( + in_addr, + out_addrs, + mon_addr, + not_addr, + reg_addr, + identity, + config=None, + loglevel=logging.DEBUG, + log_url=None, + is_leaf=False, + in_thread=False, + outgoing_ids=None, + depth=0, +): + config, ctx, loop, mons, nots, querys, log = get_common_scheduler_streams( + mon_addr, not_addr, reg_addr, config, 'scheduler', log_url, loglevel, in_thread + ) + + is_root = identity == 0 + sub_scheduler_id = get_id_with_prefix(identity) + + incoming_stream = ZMQStream(ctx.socket(zmq.ROUTER), loop) + util.set_hwm(incoming_stream, 0) + incoming_stream.setsockopt(zmq.IDENTITY, sub_scheduler_id) + + if is_root: + incoming_stream.bind(in_addr) + else: + incoming_stream.connect(in_addr) + + outgoing_streams = [] + for out_addr in out_addrs: + out = ZMQStream(ctx.socket(zmq.ROUTER), loop) + util.set_hwm(out, 0) + out.setsockopt(zmq.IDENTITY, sub_scheduler_id) + out.bind(out_addr) + outgoing_streams.append(out) + + scheduler_args = dict( + client_stream=incoming_stream, + mon_stream=mons, + notifier_stream=nots, + query_stream=querys, + loop=loop, + log=log, + config=config, + ) + if is_leaf: + scheduler_args.update(engine_stream=outgoing_streams[0], is_leaf=True) + else: + scheduler_args.update( + connected_sub_scheduler_ids=[ + get_id_with_prefix(identity) for identity in outgoing_ids + ], + outgoing_streams=outgoing_streams, + ) + + scheduler = BroadcastScheduler(**scheduler_args) + + scheduler.start() + if not in_thread: + try: + loop.start() + except KeyboardInterrupt: + scheduler.log.critical("Interrupted, exiting...") diff --git a/ipyparallel/controller/hub.py b/ipyparallel/controller/hub.py index 14ca4cffb..d8a45b1b2 100644 --- a/ipyparallel/controller/hub.py +++ b/ipyparallel/controller/hub.py @@ -14,20 +14,25 @@ import sys import time +from jupyter_client.jsonutil import parse_date from tornado.gen import coroutine, maybe_future import zmq from zmq.eventloop.zmqstream import ZMQStream # internal: from ipython_genutils.importstring import import_item + +from .broadcast_scheduler import BroadcastScheduler from ..util import extract_dates from jupyter_client.localinterfaces import localhost from ipython_genutils.py3compat import cast_bytes, unicode_type, iteritems, buffer_to_bytes_py2 from traitlets import ( HasTraits, Any, Instance, Integer, Unicode, Dict, Set, Tuple, - DottedObjectName, observe + DottedObjectName, default, observe, + List, ) +from datetime import datetime from ipyparallel import error, util from ipyparallel.factory 
import RegistrationFactory @@ -70,9 +75,15 @@ def empty_record(): 'stderr': '', } +def ensure_date_is_parsed(header): + if not isinstance(header['date'], datetime): + header['date'] = parse_date(header['date']) + def init_record(msg): """Initialize a TaskRecord based on a request.""" header = msg['header'] + + ensure_date_is_parsed(header) return { 'msg_id' : header['msg_id'], 'header' : header, @@ -107,7 +118,7 @@ class EngineConnector(HasTraits): pending: set of msg_ids stallback: tornado timeout for stalled registration """ - + id = Integer(0) uuid = Unicode() pending = Set() @@ -121,6 +132,7 @@ class EngineConnector(HasTraits): 'nodb' : 'ipyparallel.controller.dictdb.NoDB', } + class HubFactory(RegistrationFactory): """The Configurable for setting up a Hub.""" @@ -141,6 +153,38 @@ def _mux_default(self): def _task_default(self): return tuple(util.select_random_ports(2)) + broadcast_scheduler_depth = Integer( + 1, + config=True, + help="Depth of spanning tree schedulers", + ) + number_of_leaf_schedulers = Integer() + number_of_broadcast_schedulers = Integer() + number_of_non_leaf_schedulers = Integer() + + @default('number_of_leaf_schedulers') + def get_number_of_leaf_schedulers(self): + return 2 ** self.broadcast_scheduler_depth + + + @default('number_of_broadcast_schedulers') + def get_number_of_broadcast_schedulers(self): + return 2 * self.number_of_leaf_schedulers - 1 + + + @default('number_of_non_leaf_schedulers') + def get_number_of_non_leaf_schedulers(self): + return self.number_of_broadcast_schedulers - self.number_of_leaf_schedulers + + + broadcast = List(Integer(), config=True, + help="List of available ports for broadcast") + + def _broadcast_default(self): + return util.select_random_ports( + self.number_of_leaf_schedulers + self.number_of_broadcast_schedulers + ) + control = Tuple(Integer(), Integer(), config=True, help="""Client/Engine Port pair for Control queue""") @@ -203,7 +247,7 @@ def _engine_ip_default(self): registration_timeout = Integer(0, config=True, help="Engine registration timeout in seconds [default: max(30," "10*heartmonitor.period)]" ) - + def _registration_timeout_default(self): if self.heartmonitor is None: # early initialization, this value will be ignored @@ -246,14 +290,22 @@ def start(self): self.heartmonitor.start() self.log.info("Heartmonitor started") - def client_url(self, channel): + def client_url(self, channel, index=None): """return full zmq url for a named client channel""" - return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel]) - - def engine_url(self, channel): + return "%s://%s:%i" % ( + self.client_transport, + self.client_ip, + self.client_info[channel] if index is None else self.client_info[channel][index] + ) + + def engine_url(self, channel, index=None): """return full zmq url for a named engine channel""" - return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel]) - + return "%s://%s:%i" % ( + self.engine_transport, + self.engine_ip, + self.engine_info[channel] if index is None else self.engine_info[channel][index] + ) + def init_hub(self): """construct Hub object""" @@ -262,9 +314,9 @@ def init_hub(self): if 'TaskScheduler.scheme_name' in self.config: scheme = self.config.TaskScheduler.scheme_name else: - from .scheduler import TaskScheduler + from .task_scheduler import TaskScheduler scheme = TaskScheduler.scheme_name.default_value - + # build connection dicts engine = self.engine_info = { 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip), @@ -275,6 
+327,8 @@ def init_hub(self):
             'hb_pong' : self.hb[1],
             'task' : self.task[1],
             'iopub' : self.iopub[1],
+            BroadcastScheduler.port_name:
+                self.broadcast[-self.number_of_leaf_schedulers:],
         }
 
         client = self.client_info = {
@@ -286,11 +340,13 @@ def init_hub(self):
             'task_scheme' : scheme,
             'iopub' : self.iopub[0],
             'notification' : self.notifier_port,
+            BroadcastScheduler.port_name:
+                self.broadcast[:self.number_of_broadcast_schedulers],
         }
-
+
         self.log.debug("Hub engine addrs: %s", self.engine_info)
         self.log.debug("Hub client addrs: %s", self.client_info)
-
+
         # Registrar socket
         q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
         util.set_hwm(q, 0)
@@ -314,7 +370,7 @@ def init_hub(self):
         )
 
         ### Client connections ###
-
+
         # Notifier socket
         n = ZMQStream(ctx.socket(zmq.PUB), loop)
         n.bind(self.client_url('notification'))
@@ -367,9 +423,9 @@ class Hub(SessionFactory):
     client_info: dict of zmq connection information for engines
             to connect to the queues.
     """
-
+
     engine_state_file = Unicode()
-
+
     # internal data structures:
     ids=Set() # engine IDs
     keytable=Dict()
@@ -427,6 +483,8 @@ def __init__(self, **kwargs):
                                 b'out': self.save_queue_result,
                                 b'intask': self.save_task_request,
                                 b'outtask': self.save_task_result,
+                                b'inbcast': self.save_broadcast_request,
+                                b'outbcast': self.save_broadcast_result,
                                 b'tracktask': self.save_task_destination,
                                 b'incontrol': _passer,
                                 b'outcontrol': _passer,
@@ -452,7 +510,7 @@ def __init__(self, **kwargs):
         self.resubmit.on_recv(lambda msg: None, copy=False)
 
         self.log.info("hub::created hub")
-
+
     def new_engine_id(self, requested_id=None):
         """generate a new engine integer id.
 
@@ -475,7 +533,7 @@ def new_engine_id(self, requested_id=None):
         newid = self._idcounter
         self._idcounter += 1
         return newid
-
+
     #-----------------------------------------------------------------------------
     # message validation
     #-----------------------------------------------------------------------------
@@ -561,7 +619,7 @@ def dispatch_query(self, msg):
             self.session.send(self.query, "hub_error", ident=client_id,
                     content=content, parent=msg)
             return
-
+
         try:
             f = handler(idents, msg)
             if f:
@@ -686,6 +744,7 @@ def save_queue_result(self, idents, msg):
        # update record anyway, because the unregistration could have been premature
         rheader = msg['header']
         md = msg['metadata']
+        ensure_date_is_parsed(rheader)
         completed = util.ensure_timezone(rheader['date'])
         started = extract_dates(md.get('started', None))
         result = {
@@ -703,6 +762,74 @@ def save_queue_result(self, idents, msg):
         except Exception:
             self.log.error("DB Error updating record %r", msg_id, exc_info=True)
 
+    #--------------------- Broadcast traffic ------------------------------
+    def save_broadcast_request(self, idents, msg):
+        client_id = idents[0]
+        try:
+            msg = self.session.deserialize(msg)
+        except Exception:
+            self.log.error(f'broadcast:: client {client_id} sent invalid broadcast message:'
+                           f' {msg}', exc_info=True)
+            return
+
+        record = init_record(msg)
+
+        record['client_uuid'] = msg['header']['session']
+        header = msg['header']
+        msg_id = header['msg_id']
+        self.pending.add(msg_id)
+
+        try:
+            self.db.add_record(msg_id, record)
+        except Exception:
+            self.log.error(f'DB Error adding record {msg_id}', exc_info=True)
+
+    def save_broadcast_result(self, idents, msg):
+        client_id = idents[0]
+        try:
+            msg = self.session.deserialize(msg)
+        except Exception:
+            self.log.error(f'broadcast::invalid broadcast result message sent to {client_id}:'
+                           f' {msg}', exc_info=True)
+            return
+
+        # save the result of a completed broadcast
+        parent = msg['parent_header']
+        if not parent:
+            self.log.warn(f'Broadcast message {msg} had no parent')
+            return
+        msg_id = parent['msg_id']
+        header = msg['header']
+        md = msg['metadata']
+        engine_uuid = md.get('engine', u'')
+        eid = self.by_ident.get(cast_bytes(engine_uuid), None)
+        status = md.get('status', None)
+
+        if msg_id in self.pending:
+            self.log.info(f'broadcast:: broadcast {msg_id} finished on {eid}')
+            self.pending.remove(msg_id)
+            self.all_completed.add(msg_id)
+            if eid is not None and status != 'aborted':
+                self.completed[eid].append(msg_id)
+            ensure_date_is_parsed(header)
+            completed = util.ensure_timezone(header['date'])
+            started = extract_dates(md.get('started', None))
+            result = {
+                'result_header': header,
+                'result_metadata': msg['metadata'],
+                'result_content': msg['content'],
+                'started': started,
+                'completed': completed,
+                'received': util.utcnow(),
+                'engine_uuid': engine_uuid,
+                'result_buffers': msg['buffers']
+            }
+
+            try:
+                self.db.update_record(msg_id, result)
+            except Exception:
+                self.log.error(f'DB Error saving broadcast result {msg_id}', exc_info=True)
+        else:
+            self.log.debug(f'broadcast::unknown broadcast {msg_id} finished')
 
     #--------------------- Task Queue Traffic ------------------------------
 
@@ -780,7 +907,7 @@ def save_task_result(self, idents, msg):
         md = msg['metadata']
         engine_uuid = md.get('engine', u'')
         eid = self.by_ident.get(cast_bytes(engine_uuid), None)
-
+
         status = md.get('status', None)
 
         if msg_id in self.pending:
@@ -792,6 +919,7 @@ def save_task_result(self, idents, msg):
                 self.completed[eid].append(msg_id)
                 if msg_id in self.tasks[eid]:
                     self.tasks[eid].remove(msg_id)
+            ensure_date_is_parsed(header)
             completed = util.ensure_timezone(header['date'])
             started = extract_dates(md.get('started', None))
             result = {
@@ -868,13 +996,13 @@ def save_iopub_message(self, topics, msg):
         msg_id = parent['msg_id']
         msg_type = msg['header']['msg_type']
         content = msg['content']
-
+
         # ensure msg_id is in db
         try:
             rec = self.db.get_record(msg_id)
         except KeyError:
             rec = None
-
+
         # stream
         d = {}
         if msg_type == 'stream':
@@ -894,7 +1022,7 @@ def save_iopub_message(self, topics, msg):
 
         if not d:
             return
-
+
         if rec is None:
             # new record
             rec = empty_record()
@@ -904,7 +1032,7 @@ def save_iopub_message(self, topics, msg):
             update_record = self.db.add_record
         else:
             update_record = self.db.update_record
-
+
         try:
             update_record(msg_id, d)
         except Exception:
@@ -984,7 +1112,7 @@ def register_engine(self, reg, msg):
                 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=t)
         else:
             self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
-
+
         return eid
 
     def unregister_engine(self, ident, msg):
@@ -995,10 +1123,10 @@ def unregister_engine(self, ident, msg):
             self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
             return
         self.log.info("registration::unregister_engine(%r)", eid)
-
+
         uuid = self.keytable[eid]
         content=dict(id=eid, uuid=uuid)
-
+
         #stop the heartbeats
         self.hearts.pop(uuid, None)
         self.heartmonitor.responses.discard(uuid)
@@ -1073,7 +1201,7 @@ def finish_registration(self, heart):
         if self.notifier:
             self.session.send(self.notifier, "registration_notification", content=content)
         self.log.info("engine::Engine Connected: %i", eid)
-
+
         self._save_engine_state()
 
     def _purge_stalled_registration(self, heart):
@@ -1090,7 +1218,7 @@ def _purge_stalled_registration(self, heart):
 
     def _cleanup_engine_state_file(self):
         """cleanup engine state mapping"""
-
+
         if os.path.exists(self.engine_state_file):
             self.log.debug("cleaning up engine state: 
%s", self.engine_state_file) try: @@ -1108,11 +1236,11 @@ def _save_engine_state(self): engines = {} for eid, ec in self.engines.items(): engines[eid] = ec.uuid - + state['engines'] = engines - + state['next_id'] = self._idcounter - + with open(self.engine_state_file, 'w') as f: json.dump(state, f) @@ -1121,12 +1249,12 @@ def _load_engine_state(self): """load engine mapping from JSON file""" if not os.path.exists(self.engine_state_file): return - + self.log.info("loading engine state from %s" % self.engine_state_file) - + with open(self.engine_state_file) as f: state = json.load(f) - + save_notifier = self.notifier self.notifier = None for eid, uuid in iteritems(state['engines']): @@ -1134,12 +1262,12 @@ def _load_engine_state(self): # start with this heart as current and beating: self.heartmonitor.responses.add(heart) self.heartmonitor.hearts.add(heart) - + self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid) self.finish_registration(heart) - + self.notifier = save_notifier - + self._idcounter = state['next_id'] #------------------------------------------------------------------------- @@ -1312,7 +1440,7 @@ def finish(reply): msg = self.session.msg(header['msg_type'], parent=header) msg_id = msg['msg_id'] msg['content'] = rec['content'] - + # use the old header, but update msg_id and timestamp fresh = msg['header'] header['msg_id'] = fresh['msg_id'] @@ -1331,7 +1459,7 @@ def finish(reply): return finish(error.wrap_exception()) finish(dict(status='ok', resubmitted=resubmitted)) - + # store the new IDs in the Task DB for msg_id, resubmit_id in iteritems(resubmitted): try: @@ -1345,7 +1473,7 @@ def _extract_record(self, rec): io_dict = {} for key in ('execute_input', 'execute_result', 'error', 'stdout', 'stderr'): io_dict[key] = rec[key] - content = { + content = { 'header': rec['header'], 'metadata': rec['metadata'], 'result_metadata': rec['result_metadata'], @@ -1498,4 +1626,3 @@ def stop_distributed(self, client_id, msg): self.session.send(self.query, "stop_distributed_reply", content=content, parent=msg, ident=client_id, ) - diff --git a/ipyparallel/controller/scheduler.py b/ipyparallel/controller/scheduler.py index 2988bec8e..e6a6d977a 100644 --- a/ipyparallel/controller/scheduler.py +++ b/ipyparallel/controller/scheduler.py @@ -9,16 +9,10 @@ # Distributed under the terms of the Modified BSD License. 
import logging -import time -from collections import deque -from random import randint, random -from types import FunctionType +from ipython_genutils.py3compat import cast_bytes +from traitlets import observe, Instance, Set, CBytes -try: - import numpy -except ImportError: - numpy = None import zmq from zmq.eventloop import zmqstream @@ -27,208 +21,54 @@ from decorator import decorator from traitlets.config.application import Application from traitlets.config.loader import Config -from traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes, observe -from ipython_genutils.py3compat import cast_bytes -from ipyparallel import error, util -from ipyparallel.factory import SessionFactory +from ipyparallel import util from ipyparallel.util import connect_logger, local_logger, ioloop -from .dependency import Dependency +import jupyter_client.session + +jupyter_client.session.extract_dates = lambda obj: obj +from jupyter_client.session import SessionFactory + @decorator -def logged(f,self,*args,**kwargs): +def logged(f, self, *args, **kwargs): # print ("#--------------------") self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs) # print ("#--") - return f(self,*args, **kwargs) - -#---------------------------------------------------------------------- -# Chooser functions -#---------------------------------------------------------------------- - -def plainrandom(loads): - """Plain random pick.""" - n = len(loads) - return randint(0,n-1) - -def lru(loads): - """Always pick the front of the line. - - The content of `loads` is ignored. - - Assumes LRU ordering of loads, with oldest first. - """ - return 0 - -def twobin(loads): - """Pick two at random, use the LRU of the two. - - The content of loads is ignored. - - Assumes LRU ordering of loads, with oldest first. - """ - n = len(loads) - a = randint(0,n-1) - b = randint(0,n-1) - return min(a,b) - -def weighted(loads): - """Pick two at random using inverse load as weight. - - Return the less loaded of the two. - """ - # weight 0 a million times more than 1: - weights = 1./(1e-6+numpy.array(loads)) - sums = weights.cumsum() - t = sums[-1] - x = random()*t - y = random()*t - idx = 0 - idy = 0 - while sums[idx] < x: - idx += 1 - while sums[idy] < y: - idy += 1 - if weights[idy] > weights[idx]: - return idy - else: - return idx - -def leastload(loads): - """Always choose the lowest load. - - If the lowest load occurs more than once, the first - occurance will be used. If loads has LRU ordering, this means - the LRU of those with the lowest load is chosen. 
- """ - return loads.index(min(loads)) - -#--------------------------------------------------------------------- -# Classes -#--------------------------------------------------------------------- - - -# store empty default dependency: -MET = Dependency([]) - - -class Job(object): - """Simple container for a job""" - def __init__(self, msg_id, raw_msg, idents, msg, header, metadata, - targets, after, follow, timeout): - self.msg_id = msg_id - self.raw_msg = raw_msg - self.idents = idents - self.msg = msg - self.header = header - self.metadata = metadata - self.targets = targets - self.after = after - self.follow = follow - self.timeout = timeout - - self.removed = False # used for lazy-delete from sorted queue - self.timestamp = time.time() - self.timeout_id = 0 - self.blacklist = set() - - def __lt__(self, other): - return self.timestamp < other.timestamp - - def __cmp__(self, other): - return cmp(self.timestamp, other.timestamp) - - @property - def dependents(self): - return self.follow.union(self.after) - - -class TaskScheduler(SessionFactory): - """Python TaskScheduler object. - - This is the simplest object that supports msg_id based - DAG dependencies. *Only* task msg_ids are checked, not - msg_ids of jobs submitted via the MUX queue. - - """ - - hwm = Integer(1, config=True, - help="""specify the High Water Mark (HWM) for the downstream - socket in the Task scheduler. This is the maximum number - of allowed outstanding tasks on each engine. - - The default (1) means that only one task can be outstanding on each - engine. Setting TaskScheduler.hwm=0 means there is no limit, and the - engines continue to be assigned tasks while they are working, - effectively hiding network latency behind computation, but can result - in an imbalance of work when submitting many heterogenous tasks all at - once. Any positive value greater than one is a compromise between the - two. 
- - """ - ) - scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'), - 'leastload', config=True, -help="""select the task scheduler scheme [default: Python LRU] - Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'""" - ) - - @observe('scheme_name') - def _scheme_name_changed(self, change): - self.log.debug("Using scheme %r" % change['new']) - self.scheme = globals()[change['new']] - - # input arguments: - scheme = Instance(FunctionType) # function for determining the destination - def _scheme_default(self): - return leastload - client_stream = Instance(zmqstream.ZMQStream, allow_none=True) # client-facing stream - engine_stream = Instance(zmqstream.ZMQStream, allow_none=True) # engine-facing stream - notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing sub stream - mon_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing pub stream - query_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing DEALER stream - - # internals: - queue = Instance(deque) # sorted list of Jobs - def _queue_default(self): - return deque() - queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue) - graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] - retries = Dict() # dict by msg_id of retries remaining (non-neg ints) - # waiting = List() # list of msg_ids ready to run, but haven't due to HWM - pending = Dict() # dict by engine_uuid of submitted tasks - completed = Dict() # dict by engine_uuid of completed tasks - failed = Dict() # dict by engine_uuid of failed tasks - destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) - clients = Dict() # dict by msg_id for who submitted the task - targets = List() # list of target IDENTs - loads = List() # list of engine loads - # full = Set() # set of IDENTs that have HWM outstanding tasks - all_completed = Set() # set of all completed tasks - all_failed = Set() # set of all failed tasks - all_done = Set() # set of all finished tasks=union(completed,failed) - all_ids = Set() # set of all submitted task IDs - - ident = CBytes() # ZMQ identity. This should just be self.session.session - # but ensure Bytes + return f(self, *args, **kwargs) + + +class Scheduler(SessionFactory): + client_stream = Instance( + zmqstream.ZMQStream, allow_none=True + ) # client-facing stream + engine_stream = Instance( + zmqstream.ZMQStream, allow_none=True + ) # engine-facing stream + notifier_stream = Instance( + zmqstream.ZMQStream, allow_none=True + ) # hub-facing sub stream + mon_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing pub stream + query_stream = Instance( + zmqstream.ZMQStream, allow_none=True + ) # hub-facing DEALER stream + + all_completed = Set() # set of all completed tasks + all_failed = Set() # set of all failed tasks + all_done = Set() # set of all finished tasks=union(completed,failed) + all_ids = Set() # set of all submitted task IDs + + ident = CBytes() # ZMQ identity. 
This should just be self.session.session + + # but ensure Bytes def _ident_default(self): return self.session.bsession def start(self): - self.query_stream.on_recv(self.dispatch_query_reply) - self.session.send(self.query_stream, "connection_request", {}) - self.engine_stream.on_recv(self.dispatch_result, copy=False) self.client_stream.on_recv(self.dispatch_submission, copy=False) - self._notification_handlers = dict( - registration_notification = self._register_engine, - unregistration_notification = self._unregister_engine - ) - self.notifier_stream.on_recv(self.dispatch_notification) - self.log.info("Scheduler started [%s]" % self.scheme_name) - def resume_receiving(self): """Resume accepting jobs.""" self.client_stream.on_recv(self.dispatch_submission, copy=False) @@ -238,560 +78,29 @@ def stop_receiving(self): Leave them in the ZMQ queue.""" self.client_stream.on_recv(None) - #----------------------------------------------------------------------- - # [Un]Registration Handling - #----------------------------------------------------------------------- - - - def dispatch_query_reply(self, msg): - """handle reply to our initial connection request""" - try: - idents,msg = self.session.feed_identities(msg) - except ValueError: - self.log.warn("task::Invalid Message: %r",msg) - return - try: - msg = self.session.deserialize(msg) - except ValueError: - self.log.warn("task::Unauthorized message from: %r"%idents) - return - - content = msg['content'] - for uuid in content.get('engines', {}).values(): - self._register_engine(cast_bytes(uuid)) - - - @util.log_errors - def dispatch_notification(self, msg): - """dispatch register/unregister events.""" - try: - idents,msg = self.session.feed_identities(msg) - except ValueError: - self.log.warn("task::Invalid Message: %r",msg) - return - try: - msg = self.session.deserialize(msg) - except ValueError: - self.log.warn("task::Unauthorized message from: %r"%idents) - return - - msg_type = msg['header']['msg_type'] + def dispatch_result(self, raw_msg): + raise NotImplementedError("Implement in subclasses") - handler = self._notification_handlers.get(msg_type, None) - if handler is None: - self.log.error("Unhandled message type: %r"%msg_type) - else: - try: - handler(cast_bytes(msg['content']['uuid'])) - except Exception: - self.log.error("task::Invalid notification msg: %r", msg, exc_info=True) - - def _register_engine(self, uid): - """New engine with ident `uid` became available.""" - # head of the line: - self.targets.insert(0,uid) - self.loads.insert(0,0) - - # initialize sets - self.completed[uid] = set() - self.failed[uid] = set() - self.pending[uid] = {} - - # rescan the graph: - self.update_graph(None) - - def _unregister_engine(self, uid): - """Existing engine with ident `uid` became unavailable.""" - if len(self.targets) == 1: - # this was our only engine - pass - - # handle any potentially finished tasks: - self.engine_stream.flush() - - # don't pop destinations, because they might be used later - # map(self.destinations.pop, self.completed.pop(uid)) - # map(self.destinations.pop, self.failed.pop(uid)) - - # prevent this engine from receiving work - idx = self.targets.index(uid) - self.targets.pop(idx) - self.loads.pop(idx) - - # wait 5 seconds before cleaning up pending jobs, since the results might - # still be incoming - if self.pending[uid]: - self.loop.add_timeout( - self.loop.time() + 5, - lambda: self.handle_stranded_tasks(uid), - ) - else: - self.completed.pop(uid) - self.failed.pop(uid) - - def handle_stranded_tasks(self, engine): - 
"""Deal with jobs resident in an engine that died.""" - lost = self.pending[engine] - for msg_id in list(lost.keys()): - if msg_id not in lost: - # prevent double-handling of messages - continue - - raw_msg = lost[msg_id].raw_msg - idents, msg = self.session.feed_identities(raw_msg, copy=False) - parent = self.session.unpack(msg[1].bytes) - idents = [engine, idents[0]] - - # build fake error reply - try: - raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id)) - except: - content = error.wrap_exception() - # build fake metadata - md = dict( - status=u'error', - engine=engine.decode('ascii'), - date=util.utcnow(), - ) - msg = self.session.msg('apply_reply', content, parent=parent, metadata=md) - raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents))) - # and dispatch it - self.dispatch_result(raw_reply) + def dispatch_submission(self, raw_msg): + raise NotImplementedError("Implement in subclasses") - # finally scrub completed/failed lists - self.completed.pop(engine) - self.failed.pop(engine) + def append_new_msg_id_to_msg(self, new_id, target_id, idents, msg): + new_idents = [cast_bytes(target_id)] + idents + msg['header']['msg_id'] = new_id + new_msg_list = self.session.serialize(msg, ident=new_idents) + new_msg_list.extend(msg['buffers']) + return new_msg_list + def get_new_msg_id(self, original_msg_id, outgoing_id): + return f'{original_msg_id}_{outgoing_id if isinstance(outgoing_id, str) else outgoing_id.decode("utf8")}' - #----------------------------------------------------------------------- - # Job Submission - #----------------------------------------------------------------------- - - @util.log_errors - def dispatch_submission(self, raw_msg): - """Dispatch job submission to appropriate handlers.""" - # ensure targets up to date: - self.notifier_stream.flush() - try: - idents, msg = self.session.feed_identities(raw_msg, copy=False) - msg = self.session.deserialize(msg, content=False, copy=False) - except Exception: - self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True) - return - - - # send to monitor - self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False) - - header = msg['header'] - md = msg['metadata'] - msg_id = header['msg_id'] - self.all_ids.add(msg_id) - - # get targets as a set of bytes objects - # from a list of unicode objects - targets = md.get('targets', []) - targets = set(map(cast_bytes, targets)) - - retries = md.get('retries', 0) - self.retries[msg_id] = retries - - # time dependencies - after = md.get('after', None) - if after: - after = Dependency(after) - if after.all: - if after.success: - after = Dependency(after.difference(self.all_completed), - success=after.success, - failure=after.failure, - all=after.all, - ) - if after.failure: - after = Dependency(after.difference(self.all_failed), - success=after.success, - failure=after.failure, - all=after.all, - ) - if after.check(self.all_completed, self.all_failed): - # recast as empty set, if `after` already met, - # to prevent unnecessary set comparisons - after = MET - else: - after = MET - - # location dependencies - follow = Dependency(md.get('follow', [])) - - timeout = md.get('timeout', None) - if timeout: - timeout = float(timeout) - - job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, - header=header, targets=targets, after=after, follow=follow, - timeout=timeout, metadata=md, - ) - # validate and reduce dependencies: - for dep in after,follow: - if not dep: # empty dependency - continue - # check valid: - if msg_id in 
dep or dep.difference(self.all_ids): - self.queue_map[msg_id] = job - return self.fail_unreachable(msg_id, error.InvalidDependency) - # check if unreachable: - if dep.unreachable(self.all_completed, self.all_failed): - self.queue_map[msg_id] = job - return self.fail_unreachable(msg_id) - - if after.check(self.all_completed, self.all_failed): - # time deps already met, try to run - if not self.maybe_run(job): - # can't run yet - if msg_id not in self.all_failed: - # could have failed as unreachable - self.save_unmet(job) - else: - self.save_unmet(job) - - def job_timeout(self, job, timeout_id): - """callback for a job's timeout. - - The job may or may not have been run at this point. - """ - if job.timeout_id != timeout_id: - # not the most recent call - return - now = time.time() - if job.timeout >= (now + 1): - self.log.warn("task %s timeout fired prematurely: %s > %s", - job.msg_id, job.timeout, now - ) - if job.msg_id in self.queue_map: - # still waiting, but ran out of time - self.log.info("task %r timed out", job.msg_id) - self.fail_unreachable(job.msg_id, error.TaskTimeout) - - def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): - """a task has become unreachable, send a reply with an ImpossibleDependency - error.""" - if msg_id not in self.queue_map: - self.log.error("task %r already failed!", msg_id) - return - job = self.queue_map.pop(msg_id) - # lazy-delete from the queue - job.removed = True - for mid in job.dependents: - if mid in self.graph: - self.graph[mid].remove(msg_id) +ZMQStream = zmqstream.ZMQStream - try: - raise why() - except: - content = error.wrap_exception() - self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename']) - - self.all_done.add(msg_id) - self.all_failed.add(msg_id) - - msg = self.session.send(self.client_stream, 'apply_reply', content, - parent=job.header, ident=job.idents) - self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents) - - self.update_graph(msg_id, success=False) - - def available_engines(self): - """return a list of available engine indices based on HWM""" - if not self.hwm: - return list(range(len(self.targets))) - available = [] - for idx in range(len(self.targets)): - if self.loads[idx] < self.hwm: - available.append(idx) - return available - - def maybe_run(self, job): - """check location dependencies, and run if they are met.""" - msg_id = job.msg_id - self.log.debug("Attempting to assign task %s", msg_id) - available = self.available_engines() - if not available: - # no engines, definitely can't run - return False - - if job.follow or job.targets or job.blacklist or self.hwm: - # we need a can_run filter - def can_run(idx): - # check hwm - if self.hwm and self.loads[idx] == self.hwm: - return False - target = self.targets[idx] - # check blacklist - if target in job.blacklist: - return False - # check targets - if job.targets and target not in job.targets: - return False - # check follow - return job.follow.check(self.completed[target], self.failed[target]) - - indices = list(filter(can_run, available)) - - if not indices: - # couldn't run - if job.follow.all: - # check follow for impossibility - dests = set() - relevant = set() - if job.follow.success: - relevant = self.all_completed - if job.follow.failure: - relevant = relevant.union(self.all_failed) - for m in job.follow.intersection(relevant): - dests.add(self.destinations[m]) - if len(dests) > 1: - self.queue_map[msg_id] = job - self.fail_unreachable(msg_id) - return False - if job.targets: - # check blacklist+targets for 
impossibility - job.targets.difference_update(job.blacklist) - if not job.targets or not job.targets.intersection(self.targets): - self.queue_map[msg_id] = job - self.fail_unreachable(msg_id) - return False - return False - else: - indices = None - - self.submit_task(job, indices) - return True - - def save_unmet(self, job): - """Save a message for later submission when its dependencies are met.""" - msg_id = job.msg_id - self.log.debug("Adding task %s to the queue", msg_id) - self.queue_map[msg_id] = job - self.queue.append(job) - # track the ids in follow or after, but not those already finished - for dep_id in job.after.union(job.follow).difference(self.all_done): - if dep_id not in self.graph: - self.graph[dep_id] = set() - self.graph[dep_id].add(msg_id) - - # schedule timeout callback - if job.timeout: - timeout_id = job.timeout_id = job.timeout_id + 1 - self.loop.add_timeout(time.time() + job.timeout, - lambda : self.job_timeout(job, timeout_id) - ) - - - def submit_task(self, job, indices=None): - """Submit a task to any of a subset of our targets.""" - if indices: - loads = [self.loads[i] for i in indices] - else: - loads = self.loads - idx = self.scheme(loads) - if indices: - idx = indices[idx] - target = self.targets[idx] - # print (target, map(str, msg[:3])) - # send job to the engine - self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) - self.engine_stream.send_multipart(job.raw_msg, copy=False) - # update load - self.add_job(idx) - self.pending[target][job.msg_id] = job - # notify Hub - content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii')) - self.session.send(self.mon_stream, 'task_destination', content=content, - ident=[b'tracktask',self.ident]) - - - #----------------------------------------------------------------------- - # Result Handling - #----------------------------------------------------------------------- - - - @util.log_errors - def dispatch_result(self, raw_msg): - """dispatch method for result replies""" - try: - idents,msg = self.session.feed_identities(raw_msg, copy=False) - msg = self.session.deserialize(msg, content=False, copy=False) - engine = idents[0] - try: - idx = self.targets.index(engine) - except ValueError: - pass # skip load-update for dead engines - else: - self.finish_job(idx) - except Exception: - self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) - return - - md = msg['metadata'] - parent = msg['parent_header'] - if md.get('dependencies_met', True): - success = (md['status'] == 'ok') - msg_id = parent['msg_id'] - retries = self.retries[msg_id] - if not success and retries > 0: - # failed - self.retries[msg_id] = retries - 1 - self.handle_unmet_dependency(idents, parent) - else: - del self.retries[msg_id] - # relay to client and update graph - self.handle_result(idents, parent, raw_msg, success) - # send to Hub monitor - self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False) - else: - self.handle_unmet_dependency(idents, parent) - - def handle_result(self, idents, parent, raw_msg, success=True): - """handle a real task result, either success or failure""" - # first, relay result to client - engine = idents[0] - client = idents[1] - # swap_ids for ROUTER-ROUTER mirror - raw_msg[:2] = [client,engine] - # print (map(str, raw_msg[:4])) - self.client_stream.send_multipart(raw_msg, copy=False) - # now, update our data structures - msg_id = parent['msg_id'] - self.pending[engine].pop(msg_id) - if success: - self.completed[engine].add(msg_id) - self.all_completed.add(msg_id) - else: - 
self.failed[engine].add(msg_id) - self.all_failed.add(msg_id) - self.all_done.add(msg_id) - self.destinations[msg_id] = engine - - self.update_graph(msg_id, success) - - def handle_unmet_dependency(self, idents, parent): - """handle an unmet dependency""" - engine = idents[0] - msg_id = parent['msg_id'] - - job = self.pending[engine].pop(msg_id) - job.blacklist.add(engine) - - if job.blacklist == job.targets: - self.queue_map[msg_id] = job - self.fail_unreachable(msg_id) - elif not self.maybe_run(job): - # resubmit failed - if msg_id not in self.all_failed: - # put it back in our dependency tree - self.save_unmet(job) - - if self.hwm: - try: - idx = self.targets.index(engine) - except ValueError: - pass # skip load-update for dead engines - else: - if self.loads[idx] == self.hwm-1: - self.update_graph(None) - - def update_graph(self, dep_id=None, success=True): - """dep_id just finished. Update our dependency - graph and submit any jobs that just became runnable. - - Called with dep_id=None to update entire graph for hwm, but without finishing a task. - """ - # print ("\n\n***********") - # pprint (dep_id) - # pprint (self.graph) - # pprint (self.queue_map) - # pprint (self.all_completed) - # pprint (self.all_failed) - # print ("\n\n***********\n\n") - # update any jobs that depended on the dependency - msg_ids = self.graph.pop(dep_id, []) - - # recheck *all* jobs if - # a) we have HWM and an engine just become no longer full - # or b) dep_id was given as None - - if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]): - jobs = self.queue - using_queue = True - else: - using_queue = False - jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids )) - - to_restore = [] - while jobs: - job = jobs.popleft() - if job.removed: - continue - msg_id = job.msg_id - - put_it_back = True - - if job.after.unreachable(self.all_completed, self.all_failed)\ - or job.follow.unreachable(self.all_completed, self.all_failed): - self.fail_unreachable(msg_id) - put_it_back = False - - elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run - if self.maybe_run(job): - put_it_back = False - self.queue_map.pop(msg_id) - for mid in job.dependents: - if mid in self.graph: - self.graph[mid].remove(msg_id) - - # abort the loop if we just filled up all of our engines. - # avoids an O(N) operation in situation of full queue, - # where graph update is triggered as soon as an engine becomes - # non-full, and all tasks after the first are checked, - # even though they can't run. - if not self.available_engines(): - break - - if using_queue and put_it_back: - # popped a job from the queue but it neither ran nor failed, - # so we need to put it back when we are done - # make sure to_restore preserves the same ordering - to_restore.append(job) - - # put back any tasks we popped but didn't run - if using_queue: - self.queue.extendleft(to_restore) - - #---------------------------------------------------------------------- - # methods to be overridden by subclasses - #---------------------------------------------------------------------- - - def add_job(self, idx): - """Called after self.targets[idx] just got the job with header. - Override with subclasses. The default ordering is simple LRU. - The default loads are the number of outstanding jobs.""" - self.loads[idx] += 1 - for lis in (self.targets, self.loads): - lis.append(lis.pop(idx)) - - def finish_job(self, idx): - """Called after self.targets[idx] just finished a job. 
-        Override with subclasses."""
-        self.loads[idx] -= 1
-
-
-def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
-                        logname='root', log_url=None, loglevel=logging.DEBUG,
-                        identity=b'task', in_thread=False):
-
-    ZMQStream = zmqstream.ZMQStream
+def get_common_scheduler_streams(
+    mon_addr, not_addr, reg_addr, config, logname, log_url, loglevel, in_thread
+):
     if config:
         # unwrap dict back into Config
         config = Config(config)
@@ -805,23 +114,13 @@ launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=Non
     # for safety with multiprocessing
     ctx = zmq.Context()
     loop = ioloop.IOLoop()
-    ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
-    util.set_hwm(ins, 0)
-    ins.setsockopt(zmq.IDENTITY, identity + b'_in')
-    ins.bind(in_addr)
-
-    outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
-    util.set_hwm(outs, 0)
-    outs.setsockopt(zmq.IDENTITY, identity + b'_out')
-    outs.bind(out_addr)
-    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
-    util.set_hwm(mons, 0)
+    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
     mons.connect(mon_addr)
-    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
+    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB), loop)
     nots.setsockopt(zmq.SUBSCRIBE, b'')
     nots.connect(not_addr)
-    querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
+    querys = ZMQStream(ctx.socket(zmq.DEALER), loop)
     querys.connect(reg_addr)
 
     # setup logging.
@@ -829,19 +128,61 @@
         log = Application.instance().log
     else:
         if log_url:
-            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
+            log = connect_logger(
+                logname, ctx, log_url, root="scheduler", loglevel=loglevel
+            )
         else:
             log = local_logger(logname, loglevel)
+    return config, ctx, loop, mons, nots, querys, log
+
+
+def launch_scheduler(
+    scheduler_class,
+    in_addr,
+    out_addr,
+    mon_addr,
+    not_addr,
+    reg_addr,
+    config=None,
+    logname='root',
+    log_url=None,
+    loglevel=logging.DEBUG,
+    identity=None,
+    in_thread=False,
+):
+    config, ctx, loop, mons, nots, querys, log = get_common_scheduler_streams(
+        mon_addr, not_addr, reg_addr, config, logname, log_url, loglevel, in_thread
+    )
+
+    util.set_hwm(mons, 0)
+    ins = ZMQStream(ctx.socket(zmq.ROUTER), loop)
+    util.set_hwm(ins, 0)
+    if identity:
+        ins.setsockopt(zmq.IDENTITY, identity + b'_in')
+
+    ins.bind(in_addr)
+
+    outs = ZMQStream(ctx.socket(zmq.ROUTER), loop)
+    util.set_hwm(outs, 0)
+
+    if identity:
+        outs.setsockopt(zmq.IDENTITY, identity + b'_out')
+    outs.bind(out_addr)
+
+    scheduler = scheduler_class(
+        client_stream=ins,
+        engine_stream=outs,
+        mon_stream=mons,
+        notifier_stream=nots,
+        query_stream=querys,
+        loop=loop,
+        log=log,
+        config=config,
+    )
 
-    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
-                            mon_stream=mons, notifier_stream=nots,
-                            query_stream=querys,
-                            loop=loop, log=log,
-                            config=config)
     scheduler.start()
     if not in_thread:
         try:
             loop.start()
         except KeyboardInterrupt:
             scheduler.log.critical("Interrupted, exiting...")
-
diff --git a/ipyparallel/controller/task_scheduler.py b/ipyparallel/controller/task_scheduler.py
new file mode 100644
index 000000000..656de8a42
--- /dev/null
+++ b/ipyparallel/controller/task_scheduler.py
@@ -0,0 +1,791 @@
+import time
+from collections import deque
+from random import randint, random
+from types import FunctionType
+
+import zmq
+from ipython_genutils.py3compat import cast_bytes
+from traitlets import Integer, Enum, observe, Instance, Dict, List
+
+from ipyparallel import util, error, Dependency
+from 
diff --git a/ipyparallel/controller/task_scheduler.py b/ipyparallel/controller/task_scheduler.py
new file mode 100644
index 000000000..656de8a42
--- /dev/null
+++ b/ipyparallel/controller/task_scheduler.py
@@ -0,0 +1,791 @@
+import time
+from collections import deque
+from random import randint, random
+from types import FunctionType
+
+import zmq
+from ipython_genutils.py3compat import cast_bytes
+from traitlets import Integer, Enum, observe, Instance, Dict, List
+
+from ipyparallel import util, error, Dependency
+from ipyparallel.controller.scheduler import Scheduler
+
+try:
+    import numpy
+except ImportError:
+    numpy = None
+
+# ----------------------------------------------------------------------
+# Chooser functions
+# ----------------------------------------------------------------------
+
+
+def plainrandom(loads):
+    """Plain random pick."""
+    n = len(loads)
+    return randint(0, n - 1)
+
+
+def lru(loads):
+    """Always pick the front of the line.
+
+    The content of `loads` is ignored.
+
+    Assumes LRU ordering of loads, with oldest first.
+    """
+    return 0
+
+
+def twobin(loads):
+    """Pick two at random, use the LRU of the two.
+
+    The content of loads is ignored.
+
+    Assumes LRU ordering of loads, with oldest first.
+    """
+    n = len(loads)
+    a = randint(0, n - 1)
+    b = randint(0, n - 1)
+    return min(a, b)
+
+
+def weighted(loads):
+    """Pick two at random using inverse load as weight.
+
+    Return the less loaded of the two.
+    """
+    # weight 0 a million times more than 1:
+    weights = 1.0 / (1e-6 + numpy.array(loads))
+    sums = weights.cumsum()
+    t = sums[-1]
+    x = random() * t
+    y = random() * t
+    idx = 0
+    idy = 0
+    while sums[idx] < x:
+        idx += 1
+    while sums[idy] < y:
+        idy += 1
+    if weights[idy] > weights[idx]:
+        return idy
+    else:
+        return idx
+
+
+def leastload(loads):
+    """Always choose the lowest load.
+
+    If the lowest load occurs more than once, the first
+    occurrence will be used. If loads has LRU ordering, this means
+    the LRU of those with the lowest load is chosen.
+    """
+    return loads.index(min(loads))
+
+
+# ---------------------------------------------------------------------
+# Classes
+# ---------------------------------------------------------------------
+
+# store empty default dependency:
+MET = Dependency([])
+
+
+class Job(object):
+    """Simple container for a job"""
+
+    def __init__(
+        self,
+        msg_id,
+        raw_msg,
+        idents,
+        msg,
+        header,
+        metadata,
+        targets,
+        after,
+        follow,
+        timeout,
+    ):
+        self.msg_id = msg_id
+        self.raw_msg = raw_msg
+        self.idents = idents
+        self.msg = msg
+        self.header = header
+        self.metadata = metadata
+        self.targets = targets
+        self.after = after
+        self.follow = follow
+        self.timeout = timeout
+
+        self.removed = False  # used for lazy-delete from sorted queue
+        self.timestamp = time.time()
+        self.timeout_id = 0
+        self.blacklist = set()
+
+    def __lt__(self, other):
+        return self.timestamp < other.timestamp
+
+    def __cmp__(self, other):
+        return cmp(self.timestamp, other.timestamp)
+
+    @property
+    def dependents(self):
+        return self.follow.union(self.after)
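For intuition, each chooser above maps a list of per-engine loads to the index of the engine to pick. A standalone sketch of their behavior, assuming only the definitions in this file:

    loads = [3, 0, 2, 0]

    leastload(loads)    # -> 1: first occurrence of the minimum load
    lru(loads)          # -> 0: always the front of the (LRU-ordered) line
    plainrandom(loads)  # -> uniform random index in [0, 3]
    twobin(loads)       # -> min of two random indices, biased toward the LRU end
    # weighted(loads) requires numpy: it draws two indices with probability
    # proportional to 1/load and returns the less loaded of the two.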
+
+
+class TaskScheduler(Scheduler):
+    """Python TaskScheduler object.
+
+    This is the simplest object that supports msg_id based
+    DAG dependencies. *Only* task msg_ids are checked, not
+    msg_ids of jobs submitted via the MUX queue.
+
+    """
+
+    hwm = Integer(
+        1,
+        config=True,
+        help="""specify the High Water Mark (HWM) for the downstream
+        socket in the Task scheduler. This is the maximum number
+        of allowed outstanding tasks on each engine.
+
+        The default (1) means that only one task can be outstanding on each
+        engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
+        engines continue to be assigned tasks while they are working,
+        effectively hiding network latency behind computation, but can result
+        in an imbalance of work when submitting many heterogeneous tasks all at
+        once. Any positive value greater than one is a compromise between the
+        two.
+
+        """,
+    )
+
+    scheme_name = Enum(
+        ('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
+        'leastload',
+        config=True,
+        help="""select the task scheduler scheme [default: Python LRU]
+        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'""",
+    )
+
+    @observe('scheme_name')
+    def _scheme_name_changed(self, change):
+        self.log.debug("Using scheme %r" % change['new'])
+        self.scheme = globals()[change['new']]
+
+    # input arguments:
+    scheme = Instance(FunctionType)  # function for determining the destination
+
+    def _scheme_default(self):
+        return leastload
+
+    # internals:
+    queue = Instance(deque)  # sorted list of Jobs
+
+    def _queue_default(self):
+        return deque()
+
+    queue_map = Dict()  # dict by msg_id of Jobs (for O(1) access to the Queue)
+    graph = Dict()  # dict by msg_id of [ msg_ids that depend on key ]
+    retries = Dict()  # dict by msg_id of retries remaining (non-neg ints)
+    # waiting = List()  # list of msg_ids ready to run, but haven't due to HWM
+    pending = Dict()  # dict by engine_uuid of submitted tasks
+    completed = Dict()  # dict by engine_uuid of completed tasks
+    failed = Dict()  # dict by engine_uuid of failed tasks
+    destinations = (
+        Dict()
+    )  # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
+    clients = Dict()  # dict by msg_id for who submitted the task
+    targets = List()  # list of target IDENTs
+    loads = List()  # list of engine loads
+    # full = Set()  # set of IDENTs that have HWM outstanding tasks
+
+    def start(self):
+        super().start()
+        self.query_stream.on_recv(self.dispatch_query_reply)
+        self.session.send(self.query_stream, "connection_request", {})
+        self._notification_handlers = dict(
+            registration_notification=self._register_engine,
+            unregistration_notification=self._unregister_engine,
+        )
+        self.log.info("Scheduler started [%s]" % self.scheme_name)
+        self.notifier_stream.on_recv(self.dispatch_notification)
+
+    # -----------------------------------------------------------------------
+    # [Un]Registration Handling
+    # -----------------------------------------------------------------------
+
+    def dispatch_query_reply(self, msg):
+        """handle reply to our initial connection request"""
+        try:
+            idents, msg = self.session.feed_identities(msg)
+        except ValueError:
+            self.log.warn("task::Invalid Message: %r", msg)
+            return
+        try:
+            msg = self.session.deserialize(msg)
+        except ValueError:
+            self.log.warn("task::Unauthorized message from: %r" % idents)
+            return
+
+        content = msg['content']
+        for uuid in content.get('engines', {}).values():
+            self._register_engine(cast_bytes(uuid))
+
+    @util.log_errors
+    def dispatch_notification(self, msg):
+        """dispatch register/unregister events."""
+        try:
+            idents, msg = self.session.feed_identities(msg)
+        except ValueError:
+            self.log.warn("task::Invalid Message: %r", msg)
+            return
+        try:
+            msg = self.session.deserialize(msg)
+        except ValueError:
+            self.log.warn("task::Unauthorized message from: %r" % idents)
+            return
+
+        msg_type = msg['header']['msg_type']
+
+        handler = self._notification_handlers.get(msg_type, None)
+        if handler is None:
+            self.log.error("Unhandled message type: %r" % msg_type)
+        else:
+            try:
+                handler(cast_bytes(msg['content']['uuid']))
+            except Exception:
+                self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
+
+    def _register_engine(self, uid):
+        """New engine with ident `uid` became available."""
+        # head of the line:
+        self.targets.insert(0, uid)
+        self.loads.insert(0, 0)
+
+        # initialize sets
+        self.completed[uid] = set()
+        self.failed[uid] = set()
+        self.pending[uid] = {}
+
+        # rescan the graph:
+        self.update_graph(None)
+
+    def _unregister_engine(self, uid):
+        """Existing engine with ident `uid` became unavailable."""
+        if len(self.targets) == 1:
+            # this was our only engine
+            pass
+
+        # handle any potentially finished tasks:
+        self.engine_stream.flush()
+
+        # don't pop destinations, because they might be used later
+        # map(self.destinations.pop, self.completed.pop(uid))
+        # map(self.destinations.pop, self.failed.pop(uid))
+
+        # prevent this engine from receiving work
+        idx = self.targets.index(uid)
+        self.targets.pop(idx)
+        self.loads.pop(idx)
+
+        # wait 5 seconds before cleaning up pending jobs, since the results might
+        # still be incoming
+        if self.pending[uid]:
+            self.loop.add_timeout(
+                self.loop.time() + 5, lambda: self.handle_stranded_tasks(uid)
+            )
+        else:
+            self.completed.pop(uid)
+            self.failed.pop(uid)
+
+    def handle_stranded_tasks(self, engine):
+        """Deal with jobs resident in an engine that died."""
+        lost = self.pending[engine]
+        for msg_id in list(lost.keys()):
+            if msg_id not in lost:
+                # prevent double-handling of messages
+                continue
+
+            raw_msg = lost[msg_id].raw_msg
+            idents, msg = self.session.feed_identities(raw_msg, copy=False)
+            parent = self.session.unpack(msg[1].bytes)
+            idents = [engine, idents[0]]
+
+            # build fake error reply
+            try:
+                raise error.EngineError(
+                    "Engine %r died while running task %r" % (engine, msg_id)
+                )
+            except:
+                content = error.wrap_exception()
+            # build fake metadata
+            md = dict(
+                status=u'error', engine=engine.decode('ascii'), date=util.utcnow()
+            )
+            msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
+            raw_reply = list(
+                map(zmq.Message, self.session.serialize(msg, ident=idents))
+            )
+            # and dispatch it
+            self.dispatch_result(raw_reply)
+
+        # finally scrub completed/failed lists
+        self.completed.pop(engine)
+        self.failed.pop(engine)
+
+    # -----------------------------------------------------------------------
+    # Job Submission
+    # -----------------------------------------------------------------------
+
+    @util.log_errors
+    def dispatch_submission(self, raw_msg):
+        """Dispatch job submission to appropriate handlers."""
+        # ensure targets up to date:
+        self.notifier_stream.flush()
+        try:
+            idents, msg = self.session.feed_identities(raw_msg, copy=False)
+            msg = self.session.deserialize(msg, content=False, copy=False)
+        except Exception:
+            self.log.error("task::Invalid task msg: %r" % raw_msg, exc_info=True)
+            return
+
+        # send to monitor
+        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)
+
+        header = msg['header']
+        md = msg['metadata']
+        msg_id = header['msg_id']
+        self.all_ids.add(msg_id)
+
+        # get targets as a set of bytes objects
+        # from a list of unicode objects
+        targets = md.get('targets', [])
+        targets = set(map(cast_bytes, targets))
+
+        retries = md.get('retries', 0)
+        self.retries[msg_id] = retries
+
+        # time dependencies
+        after = md.get('after', None)
+        if after:
+            after = Dependency(after)
+            if after.all:
+                if after.success:
+                    after = Dependency(
+                        after.difference(self.all_completed),
+                        success=after.success,
+                        failure=after.failure,
+                        all=after.all,
+                    )
+                if after.failure:
+                    after = Dependency(
+                        after.difference(self.all_failed),
+
success=after.success, + failure=after.failure, + all=after.all, + ) + if after.check(self.all_completed, self.all_failed): + # recast as empty set, if `after` already met, + # to prevent unnecessary set comparisons + after = MET + else: + after = MET + + # location dependencies + follow = Dependency(md.get('follow', [])) + + timeout = md.get('timeout', None) + if timeout: + timeout = float(timeout) + + job = Job( + msg_id=msg_id, + raw_msg=raw_msg, + idents=idents, + msg=msg, + header=header, + targets=targets, + after=after, + follow=follow, + timeout=timeout, + metadata=md, + ) + # validate and reduce dependencies: + for dep in after, follow: + if not dep: # empty dependency + continue + # check valid: + if msg_id in dep or dep.difference(self.all_ids): + self.queue_map[msg_id] = job + return self.fail_unreachable(msg_id, error.InvalidDependency) + # check if unreachable: + if dep.unreachable(self.all_completed, self.all_failed): + self.queue_map[msg_id] = job + return self.fail_unreachable(msg_id) + + if after.check(self.all_completed, self.all_failed): + # time deps already met, try to run + if not self.maybe_run(job): + # can't run yet + if msg_id not in self.all_failed: + # could have failed as unreachable + self.save_unmet(job) + else: + self.save_unmet(job) + + def job_timeout(self, job, timeout_id): + """callback for a job's timeout. + + The job may or may not have been run at this point. + """ + if job.timeout_id != timeout_id: + # not the most recent call + return + now = time.time() + if job.timeout >= (now + 1): + self.log.warn( + "task %s timeout fired prematurely: %s > %s", + job.msg_id, + job.timeout, + now, + ) + if job.msg_id in self.queue_map: + # still waiting, but ran out of time + self.log.info("task %r timed out", job.msg_id) + self.fail_unreachable(job.msg_id, error.TaskTimeout) + + def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): + """a task has become unreachable, send a reply with an ImpossibleDependency + error.""" + if msg_id not in self.queue_map: + self.log.error("task %r already failed!", msg_id) + return + job = self.queue_map.pop(msg_id) + # lazy-delete from the queue + job.removed = True + for mid in job.dependents: + if mid in self.graph: + self.graph[mid].remove(msg_id) + + try: + raise why() + except: + content = error.wrap_exception() + self.log.debug( + "task %r failing as unreachable with: %s", msg_id, content['ename'] + ) + + self.all_done.add(msg_id) + self.all_failed.add(msg_id) + + msg = self.session.send( + self.client_stream, + 'apply_reply', + content, + parent=job.header, + ident=job.idents, + ) + self.session.send(self.mon_stream, msg, ident=[b'outtask'] + job.idents) + + self.update_graph(msg_id, success=False) + + def available_engines(self): + """return a list of available engine indices based on HWM""" + if not self.hwm: + return list(range(len(self.targets))) + available = [] + for idx in range(len(self.targets)): + if self.loads[idx] < self.hwm: + available.append(idx) + return available + + def maybe_run(self, job): + """check location dependencies, and run if they are met.""" + msg_id = job.msg_id + self.log.debug("Attempting to assign task %s", msg_id) + available = self.available_engines() + if not available: + # no engines, definitely can't run + return False + + if job.follow or job.targets or job.blacklist or self.hwm: + # we need a can_run filter + def can_run(idx): + # check hwm + if self.hwm and self.loads[idx] == self.hwm: + return False + target = self.targets[idx] + # check blacklist + if target in 
job.blacklist:
+                    return False
+                # check targets
+                if job.targets and target not in job.targets:
+                    return False
+                # check follow
+                return job.follow.check(self.completed[target], self.failed[target])
+
+            indices = list(filter(can_run, available))
+
+            if not indices:
+                # couldn't run
+                if job.follow.all:
+                    # check follow for impossibility
+                    dests = set()
+                    relevant = set()
+                    if job.follow.success:
+                        relevant = self.all_completed
+                    if job.follow.failure:
+                        relevant = relevant.union(self.all_failed)
+                    for m in job.follow.intersection(relevant):
+                        dests.add(self.destinations[m])
+                    if len(dests) > 1:
+                        self.queue_map[msg_id] = job
+                        self.fail_unreachable(msg_id)
+                        return False
+                if job.targets:
+                    # check blacklist+targets for impossibility
+                    job.targets.difference_update(job.blacklist)
+                    if not job.targets or not job.targets.intersection(self.targets):
+                        self.queue_map[msg_id] = job
+                        self.fail_unreachable(msg_id)
+                        return False
+                return False
+        else:
+            indices = None
+
+        self.submit_task(job, indices)
+        return True
+
+    def save_unmet(self, job):
+        """Save a message for later submission when its dependencies are met."""
+        msg_id = job.msg_id
+        self.log.debug("Adding task %s to the queue", msg_id)
+        self.queue_map[msg_id] = job
+        self.queue.append(job)
+        # track the ids in follow or after, but not those already finished
+        for dep_id in job.after.union(job.follow).difference(self.all_done):
+            if dep_id not in self.graph:
+                self.graph[dep_id] = set()
+            self.graph[dep_id].add(msg_id)
+
+        # schedule timeout callback
+        if job.timeout:
+            timeout_id = job.timeout_id = job.timeout_id + 1
+            self.loop.add_timeout(
+                time.time() + job.timeout, lambda: self.job_timeout(job, timeout_id)
+            )
+
+    def submit_task(self, job, indices=None):
+        """Submit a task to any of a subset of our targets."""
+        if indices:
+            loads = [self.loads[i] for i in indices]
+        else:
+            loads = self.loads
+        idx = self.scheme(loads)
+        if indices:
+            idx = indices[idx]
+        target = self.targets[idx]
+        # print (target, map(str, msg[:3]))
+        # send job to the engine
+        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
+        self.engine_stream.send_multipart(job.raw_msg, copy=False)
+        # update load
+        self.add_job(idx)
+        self.pending[target][job.msg_id] = job
+        # notify Hub
+        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
+        self.session.send(
+            self.mon_stream,
+            'task_destination',
+            content=content,
+            ident=[b'tracktask', self.ident],
+        )
+
+    # -----------------------------------------------------------------------
+    # Result Handling
+    # -----------------------------------------------------------------------
+
+    @util.log_errors
+    def dispatch_result(self, raw_msg):  # maybe_dispatch_results?
+ """dispatch method for result replies""" + try: + idents, msg = self.session.feed_identities(raw_msg, copy=False) + msg = self.session.deserialize(msg, content=False, copy=False) + engine = idents[0] + try: + idx = self.targets.index(engine) + except ValueError: + pass # skip load-update for dead engines + else: + self.finish_job(idx) + except Exception: + self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) + return + + md = msg['metadata'] + parent = msg['parent_header'] + if md.get('dependencies_met', True): + success = md['status'] == 'ok' + msg_id = parent['msg_id'] + retries = self.retries[msg_id] + if not success and retries > 0: + # failed + self.retries[msg_id] = retries - 1 + self.handle_unmet_dependency(idents, parent) + else: + del self.retries[msg_id] + # relay to client and update graph + self.handle_result(idents, parent, raw_msg, success) + # send to Hub monitor + self.mon_stream.send_multipart([b'outtask'] + raw_msg, copy=False) + else: + self.handle_unmet_dependency(idents, parent) + + def handle_result(self, idents, parent, raw_msg, success=True): + """handle a real task result, either success or failure""" + # first, relay result to client + engine = idents[0] + client = idents[1] + # swap_ids for ROUTER-ROUTER mirror + raw_msg[:2] = [client, engine] + # print (map(str, raw_msg[:4])) + self.client_stream.send_multipart(raw_msg, copy=False) + # now, update our data structures + msg_id = parent['msg_id'] + self.pending[engine].pop(msg_id) + if success: + self.completed[engine].add(msg_id) + self.all_completed.add(msg_id) + else: + self.failed[engine].add(msg_id) + self.all_failed.add(msg_id) + self.all_done.add(msg_id) + self.destinations[msg_id] = engine + + self.update_graph(msg_id, success) + + def handle_unmet_dependency(self, idents, parent): + """handle an unmet dependency""" + engine = idents[0] + msg_id = parent['msg_id'] + + job = self.pending[engine].pop(msg_id) + job.blacklist.add(engine) + + if job.blacklist == job.targets: + self.queue_map[msg_id] = job + self.fail_unreachable(msg_id) + elif not self.maybe_run(job): + # resubmit failed + if msg_id not in self.all_failed: + # put it back in our dependency tree + self.save_unmet(job) + + if self.hwm: + try: + idx = self.targets.index(engine) + except ValueError: + pass # skip load-update for dead engines + else: + if self.loads[idx] == self.hwm - 1: + self.update_graph(None) + + def update_graph(self, dep_id=None, success=True): + """dep_id just finished. Update our dependency + graph and submit any jobs that just became runnable. + + Called with dep_id=None to update entire graph for hwm, but without finishing a task. 
+ """ + # print ("\n\n***********") + # pprint (dep_id) + # pprint (self.graph) + # pprint (self.queue_map) + # pprint (self.all_completed) + # pprint (self.all_failed) + # print ("\n\n***********\n\n") + # update any jobs that depended on the dependency + msg_ids = self.graph.pop(dep_id, []) + + # recheck *all* jobs if + # a) we have HWM and an engine just become no longer full + # or b) dep_id was given as None + + if ( + dep_id is None + or self.hwm + and any([load == self.hwm - 1 for load in self.loads]) + ): + jobs = self.queue + using_queue = True + else: + using_queue = False + jobs = deque(sorted(self.queue_map[msg_id] for msg_id in msg_ids)) + + to_restore = [] + while jobs: + job = jobs.popleft() + if job.removed: + continue + msg_id = job.msg_id + + put_it_back = True + + if job.after.unreachable( + self.all_completed, self.all_failed + ) or job.follow.unreachable(self.all_completed, self.all_failed): + self.fail_unreachable(msg_id) + put_it_back = False + + elif job.after.check( + self.all_completed, self.all_failed + ): # time deps met, maybe run + if self.maybe_run(job): + put_it_back = False + self.queue_map.pop(msg_id) + for mid in job.dependents: + if mid in self.graph: + self.graph[mid].remove(msg_id) + + # abort the loop if we just filled up all of our engines. + # avoids an O(N) operation in situation of full queue, + # where graph update is triggered as soon as an engine becomes + # non-full, and all tasks after the first are checked, + # even though they can't run. + if not self.available_engines(): + break + + if using_queue and put_it_back: + # popped a job from the queue but it neither ran nor failed, + # so we need to put it back when we are done + # make sure to_restore preserves the same ordering + to_restore.append(job) + + # put back any tasks we popped but didn't run + if using_queue: + self.queue.extendleft(to_restore) + + # ---------------------------------------------------------------------- + # methods to be overridden by subclasses + # ---------------------------------------------------------------------- + + def add_job(self, idx): + """Called after self.targets[idx] just got the job with header. + Override with subclasses. The default ordering is simple LRU. + The default loads are the number of outstanding jobs.""" + self.loads[idx] += 1 + for lis in (self.targets, self.loads): + lis.append(lis.pop(idx)) + + def finish_job(self, idx): + """Called after self.targets[idx] just finished a job. 
+ Override with subclasses.""" + self.loads[idx] -= 1 diff --git a/ipyparallel/engine/engine.py b/ipyparallel/engine/engine.py index 58ecd7224..c44bc2900 100644 --- a/ipyparallel/engine/engine.py +++ b/ipyparallel/engine/engine.py @@ -205,7 +205,10 @@ def complete_registration(self, msg, connect, maybe_tunnel): def url(key): """get zmq url for given channel""" return str(info["interface"] + ":%i" % info[key]) - + + def urls(key): + return [f'{info["interface"]}:{port}' for port in info[key]] + if content['status'] == 'ok': if self.id is not None and content['id'] != self.id: self.log.warning("Did not get the requested id: %i != %i", content['id'], self.id) @@ -232,13 +235,18 @@ def url(key): heart.start() # create Shell Connections (MUX, Task, etc.): - shell_addrs = url('mux'), url('task') + shell_addrs = [url('mux'), url('task')] + urls('broadcast') + + self.log.info(f'ENGINE: shell_addrs: {shell_addrs}') # Use only one shell stream for mux and tasks stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop) stream.setsockopt(zmq.IDENTITY, identity) + self.log.debug("Setting shell identity %r", identity) + shell_streams = [stream] for addr in shell_addrs: + self.log.info("Connecting shell to %s", addr) connect(stream, addr) # control stream: diff --git a/ipyparallel/engine/kernel.py b/ipyparallel/engine/kernel.py index 31112d310..69ee17024 100644 --- a/ipyparallel/engine/kernel.py +++ b/ipyparallel/engine/kernel.py @@ -54,10 +54,14 @@ def should_handle(self, stream, msg, idents): def init_metadata(self, parent): """init metadata dict, for execute/apply_reply""" + parent_metadata = parent.get('metadata', {}) return { 'started': utcnow(), 'dependencies_met' : True, 'engine' : self.ident, + 'is_broadcast': parent_metadata.get('is_broadcast', False), + 'is_coalescing': parent_metadata.get('is_coalescing', False), + 'original_msg_id': parent_metadata.get('original_msg_id', ''), } def finish_metadata(self, parent, metadata, reply_content): @@ -86,16 +90,16 @@ def apply_request(self, stream, ident, parent): return md = self.init_metadata(parent) - reply_content, result_buf = self.do_apply(content, bufs, msg_id, md) # put 'ok'/'error' status in header, for scheduler introspection: md = self.finish_metadata(parent, md, reply_content) # flush i/o + self.log.info(f'ENGINE apply request, ident: {ident}') sys.stdout.flush() sys.stderr.flush() - + self.log.debug('engine: sending apply_reply') self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident, buffers=result_buf, metadata=md) diff --git a/ipyparallel/util.py b/ipyparallel/util.py index fea5761c6..88e1c26ee 100644 --- a/ipyparallel/util.py +++ b/ipyparallel/util.py @@ -123,7 +123,7 @@ def log_errors(f, self, *args, **kwargs): """ try: return f(self, *args, **kwargs) - except Exception: + except Exception as e: self.log.error("Uncaught exception in %r" % f, exc_info=True)
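The `add_job`/`finish_job` pair in the scheduler above is the default load accounting: assigning a task bumps an engine's load and rotates it to the back of the LRU line, while finishing a task only decrements the load. A standalone re-creation with plain lists (illustrative only; the engine idents are hypothetical, and plain lists stand in for the scheduler's traitlets):

    targets = [b'engine-a', b'engine-b', b'engine-c']  # hypothetical idents
    loads = [0, 0, 0]

    def add_job(idx):
        # bump load, then rotate engine idx to the back of the line
        loads[idx] += 1
        for lis in (targets, loads):
            lis.append(lis.pop(idx))

    def finish_job(idx):
        # a task on targets[idx] finished; shed one unit of load
        loads[idx] -= 1

    add_job(0)                 # engine-a takes a task, moves to the back
    print(targets, loads)      # [b'engine-b', b'engine-c', b'engine-a'] [0, 0, 1]
    finish_job(targets.index(b'engine-a'))
    print(loads)               # [0, 0, 0]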