System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS High Sierra
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
- TensorFlow installed from (source or binary): tf-nightly
- TensorFlow version (use command below): ('v1.9.0-rc2-4081-g626bc997c2', '1.11.0-dev20180913')
- Python version: Python 2.7.15
- Bazel version (if compiling from source): N/A
- GCC/Compiler version (if compiling from source): N/A
- CUDA/cuDNN version: N/A
- GPU model and memory: N/A
- Exact command to reproduce: https://gist.github.com/df3df82f7ae8f47b6288fc42eb8c8b17
Describe the problem
Invoking tf.estimator.train_and_evaluate with CollectiveAllReduceStrategy fails on CPU-only worker nodes with the following message:
InternalError: ScopedAllocatorMgr not supported on device /job:worker/replica:0/task:0/device:CPU:0
Source code / logs
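For context, model_fn, train_spec and eval_spec used below are defined in the gist linked above; here is a minimal hypothetical sketch of what they could look like (the model, input pipeline and step counts are my own illustration, not the gist's exact code):

import tensorflow as tf

def model_fn(features, labels, mode):
    # Trivial one-unit linear model, just so collective all-reduce has gradients to combine.
    logits = tf.layers.dense(features['x'], 1)
    loss = tf.losses.mean_squared_error(labels, logits)
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn():
    # Tiny in-memory dataset, repeated and batched.
    dataset = tf.data.Dataset.from_tensor_slices(({'x': [[1.0], [2.0]]}, [[1.0], [2.0]]))
    return dataset.repeat().batch(2)

train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=100)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)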
import tensorflow as tf

from tensorflow.contrib.distribute import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute import DistributeConfig

# CPU-only workers: zero GPUs per worker.
distribution = CollectiveAllReduceStrategy(num_gpus_per_worker=0)
config = tf.estimator.RunConfig(
    experimental_distribute=DistributeConfig(
        train_distribute=distribution,
        remote_cluster={
            'worker': ['localhost:5000', 'localhost:5001'],
        },
    )
)
# model_fn, train_spec and eval_spec come from the linked gist (see the sketch above).
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
tf.estimator.train_and_evaluate(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)

INFO:tensorflow:CollectiveAllReduceStrategy with local_devices = ['/device:CPU:0']
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:RunConfig initialized for Distribute Coordinator with STANDALONE_CLIENT mode
WARNING:tensorflow:Using temporary folder as model directory: /var/folders/gn/sjntndrs1fs22kfr302697mr0000gn/T/tmpeIE_x6
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': {'worker': ['localhost:5000', 'localhost:5001']}, '_model_dir': '/var/folders/gn/sjntndrs1fs22kfr302697mr0000gn/T/tmpeIE_x6', '_protocol': None, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_tf_random_seed': None, '_save_summary_steps': 100, '_device_fn': None, '_experimental_distribute': DistributeConfig(train_distribute=<tensorflow.contrib.distribute.python.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x123f09810>, eval_distribute=None, remote_cluster={'worker': ['localhost:5000', 'localhost:5001']}), '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': <tensorflow.contrib.distribute.python.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x123f09810>, '_master': '', '_distribute_coordinator_mode': 'standalone_client'}
INFO:tensorflow:Running `train_and_evaluate` with Distribute Coordinator.
INFO:tensorflow:Running Distribute Coordinator with mode = 'standalone_client', cluster_spec = {'worker': ['localhost:5000', 'localhost:5001']}, task_type = None, task_id = None, environment = None, rpc_layer = 'grpc'
WARNING:tensorflow:`eval_strategy` is not passed in. No distribution strategy will be used for evaluation.
INFO:tensorflow:Multi-worker CollectiveAllReduceStrategy with cluster_spec = {'worker': ['localhost:5000', 'localhost:5001']}, task_type = 'worker', task_id = 0, num_workers = 2, local_devices = ['/job:worker/task:0']
INFO:tensorflow:Multi-worker CollectiveAllReduceStrategy with cluster_spec = {'worker': ['localhost:5000', 'localhost:5001']}, task_type = 'worker', task_id = 1, num_workers = 2, local_devices = ['/job:worker/task:1']
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Collective All-reduce invoked with batches size = 2, num_workers = 2
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Creating chief session creator with config: device_filters: "/job:worker/task:0"
allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
scoped_allocator_optimization: ON
scoped_allocator_opts {
enable_op: "CollectiveReduce"
}
}
}
isolate_session_state: true
experimental {
collective_group_leader: "/job:worker/replica:0/task:0"
}
INFO:tensorflow:Collective All-reduce invoked with batches size = 2, num_workers = 2
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Creating chief session creator with config: device_filters: "/job:worker/task:1"
allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
scoped_allocator_optimization: ON
scoped_allocator_opts {
enable_op: "CollectiveReduce"
}
}
}
isolate_session_state: true
experimental {
collective_group_leader: "/job:worker/replica:0/task:0"
}
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Initialize system
INFO:tensorflow:Initialize system
INFO:tensorflow:Saving checkpoints for 0 into /var/folders/gn/sjntndrs1fs22kfr302697mr0000gn/T/tmpeIE_x6/model.ckpt.
Exception in thread Thread-5:
Traceback (most recent call last):
File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
worker_fn(strategy)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/distribute/estimator_training.py", line 232, in _worker_fn
hooks=list(train_spec.hooks))
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 355, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1178, in _train_model
return self._train_model_distributed(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1325, in _train_model_distributed
saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1408, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 671, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1148, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1239, in run
raise six.reraise(*original_exc_info)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1224, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1296, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 887, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1306, in _do_call
raise type(e)(node_def, op, message)
InternalError: ScopedAllocatorMgr not supported on device /job:worker/replica:0/task:1/device:CPU:0
[[{{node scoped_allocator_1}} = _ScopedAllocator[T=DT_FLOAT, expected_call_count=2, id=1, sa_name="scoped_allocator_1", shape=[17], shapes=[[1,1], [1]], _device="/job:worker/replica:0/task:1/device:CPU:0"]()]]
Exception in thread Thread-4:
Traceback (most recent call last):
File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/usr/local/Cellar/python@2/2.7.15_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
worker_fn(strategy)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/distribute/estimator_training.py", line 232, in _worker_fn
hooks=list(train_spec.hooks))
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 355, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1178, in _train_model
return self._train_model_distributed(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1325, in _train_model_distributed
saving_listeners)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1408, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 671, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1148, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1239, in run
raise six.reraise(*original_exc_info)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1224, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1296, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 887, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
run_metadata)
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1306, in _do_call
raise type(e)(node_def, op, message)
InternalError: ScopedAllocatorMgr not supported on device /job:worker/replica:0/task:0/device:CPU:0
[[{{node scoped_allocator_1}} = _ScopedAllocator[T=DT_FLOAT, expected_call_count=2, id=1, sa_name="scoped_allocator_1", shape=[17], shapes=[[1,1], [1]], _device="/job:worker/replica:0/task:0/device:CPU:0"]()]]
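For reference, the chief session config printed above shows that CollectiveAllReduceStrategy turns on the Grappler scoped-allocator rewrite for CollectiveReduce, which is where the _ScopedAllocator node placed on the CPU device comes from. Built by hand, that config corresponds roughly to the proto construction below; this is only a sketch of the knobs involved, not something the strategy exposes directly:

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

# Roughly the options visible in the "Creating chief session creator with config" log above.
rewrite_options = rewriter_config_pb2.RewriterConfig(
    meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
    scoped_allocator_optimization=rewriter_config_pb2.RewriterConfig.ON,
    scoped_allocator_opts=rewriter_config_pb2.ScopedAllocatorOptions(
        enable_op=['CollectiveReduce']),
)
session_config = tf.ConfigProto(
    allow_soft_placement=True,
    isolate_session_state=True,
    device_filters=['/job:worker/task:0'],
    graph_options=tf.GraphOptions(rewrite_options=rewrite_options),
    experimental=tf.ConfigProto.Experimental(
        collective_group_leader='/job:worker/replica:0/task:0'),
)

I would guess that flipping scoped_allocator_optimization to RewriterConfig.OFF avoids the _ScopedAllocator placement on CPU, but I have not verified whether the strategy honors a user-supplied session_config here.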
PS: thanks @yuefengz for today's introduction to the multi-node distribution strategy at TF Roadshow 2018 @ Beijing 😆