diff --git a/keras/engine/network.py b/keras/engine/network.py index c493f05b15b0..10972e53010c 100644 --- a/keras/engine/network.py +++ b/keras/engine/network.py @@ -139,14 +139,8 @@ def _base_init(self, name=None): def _init_graph_network(self, inputs, outputs, name=None): self._uses_inputs_arg = True # Normalize and set self.inputs, self.outputs. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] + self.inputs = to_list(inputs, allow_tuple=True) + self.outputs = to_list(outputs, allow_tuple=True) # User-provided argument validation. # Check for redundancy in inputs. diff --git a/keras/engine/training.py b/keras/engine/training.py index 796e8599a89b..1a6f1bdbc476 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -596,10 +596,7 @@ def _set_inputs(self, inputs, outputs=None, training=None): self._feed_inputs = [] self._feed_input_names = [] self._feed_input_shapes = [] - if isinstance(inputs, (list, tuple)): - inputs = list(inputs) - else: - inputs = [inputs] + inputs = to_list(inputs, allow_tuple=True) for i, v in enumerate(inputs): name = 'input_%d' % (i + 1) @@ -633,10 +630,7 @@ def _set_inputs(self, inputs, outputs=None, training=None): outputs = self.call(unpack_singleton(self.inputs), training=training) else: outputs = self.call(unpack_singleton(self.inputs)) - if isinstance(outputs, (list, tuple)): - outputs = list(outputs) - else: - outputs = [outputs] + outputs = to_list(outputs, allow_tuple=True) self.outputs = outputs self.output_names = [ 'output_%d' % (i + 1) for i in range(len(self.outputs))] @@ -704,10 +698,7 @@ def _standardize_user_data(self, x, 'You passed: y=' + str(y)) # Typecheck that all inputs are *either* value *or* symbolic. if y is not None: - if isinstance(y, (list, tuple)): - all_inputs += list(y) - else: - all_inputs.append(y) + all_inputs += to_list(y, allow_tuple=True) if any(K.is_tensor(v) for v in all_inputs): if not all(K.is_tensor(v) for v in all_inputs): raise ValueError('Do not pass inputs that mix Numpy ' @@ -716,8 +707,7 @@ def _standardize_user_data(self, x, '; y=' + str(y)) # Handle target tensors if any passed. - if not isinstance(y, (list, tuple)): - y = [y] + y = to_list(y, allow_tuple=True) target_tensors = [v for v in y if K.is_tensor(v)] if not target_tensors: target_tensors = None diff --git a/keras/layers/advanced_activations.py b/keras/layers/advanced_activations.py index 763944403953..f6cd51dcfe89 100644 --- a/keras/layers/advanced_activations.py +++ b/keras/layers/advanced_activations.py @@ -13,6 +13,7 @@ from ..engine.base_layer import InputSpec from .. import backend as K from ..legacy import interfaces +from ..utils.generic_utils import to_list class LeakyReLU(Layer): @@ -100,10 +101,8 @@ def __init__(self, alpha_initializer='zeros', self.alpha_constraint = constraints.get(alpha_constraint) if shared_axes is None: self.shared_axes = None - elif not isinstance(shared_axes, (list, tuple)): - self.shared_axes = [shared_axes] else: - self.shared_axes = list(shared_axes) + self.shared_axes = to_list(shared_axes, allow_tuple=True) def build(self, input_shape): param_shape = list(input_shape[1:]) diff --git a/keras/layers/convolutional_recurrent.py b/keras/layers/convolutional_recurrent.py index dc1f7bbe5a3f..5ac70c079a8d 100644 --- a/keras/layers/convolutional_recurrent.py +++ b/keras/layers/convolutional_recurrent.py @@ -21,6 +21,7 @@ from ..legacy.layers import Recurrent, ConvRecurrent2D from .recurrent import RNN from ..utils.generic_utils import has_arg +from ..utils.generic_utils import to_list from ..utils.generic_utils import transpose_shape @@ -387,10 +388,7 @@ def step(inputs, states): output._uses_learning_phase = True if self.return_state: - if not isinstance(states, (list, tuple)): - states = [states] - else: - states = list(states) + states = to_list(states, allow_tuple=True) return [output] + states else: return output @@ -443,8 +441,7 @@ def get_tuple_shape(nb_channels): K.set_value(self.states[0], np.zeros(get_tuple_shape(self.cell.state_size))) else: - if not isinstance(states, (list, tuple)): - states = [states] + states = to_list(states, allow_tuple=True) if len(states) != len(self.states): raise ValueError('Layer ' + self.name + ' expects ' + str(len(self.states)) + ' states, ' diff --git a/keras/layers/embeddings.py b/keras/layers/embeddings.py index 002ce5480dec..35b018d041f9 100644 --- a/keras/layers/embeddings.py +++ b/keras/layers/embeddings.py @@ -10,6 +10,7 @@ from .. import constraints from ..engine.base_layer import Layer from ..legacy import interfaces +from ..utils.generic_utils import to_list class Embedding(Layer): @@ -117,10 +118,7 @@ def compute_output_shape(self, input_shape): return input_shape + (self.output_dim,) else: # input_length can be tuple if input is 3D or higher - if isinstance(self.input_length, (list, tuple)): - in_lens = list(self.input_length) - else: - in_lens = [self.input_length] + in_lens = to_list(self.input_length, allow_tuple=True) if len(in_lens) != len(input_shape) - 1: raise ValueError('"input_length" is %s, but received input has shape %s' % (str(self.input_length), str(input_shape))) diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py index c82e6a32c65b..201c68cd0efc 100644 --- a/keras/layers/recurrent.py +++ b/keras/layers/recurrent.py @@ -16,6 +16,7 @@ from ..engine.base_layer import Layer from ..engine.base_layer import InputSpec from ..utils.generic_utils import has_arg +from ..utils.generic_utils import to_list # Legacy support. from ..legacy.layers import Recurrent @@ -664,10 +665,7 @@ def step(inputs, states): state._uses_learning_phase = True if self.return_state: - if not isinstance(states, (list, tuple)): - states = [states] - else: - states = list(states) + states = to_list(states, allow_tuple=True) return [output] + states else: return output @@ -702,8 +700,7 @@ def reset_states(self, states=None): K.set_value(self.states[0], np.zeros((batch_size, self.cell.state_size))) else: - if not isinstance(states, (list, tuple)): - states = [states] + states = to_list(states, allow_tuple=True) if len(states) != len(self.states): raise ValueError('Layer ' + self.name + ' expects ' + str(len(self.states)) + ' states, ' diff --git a/keras/legacy/layers.py b/keras/legacy/layers.py index be869335bca1..85c4130b1cd4 100644 --- a/keras/legacy/layers.py +++ b/keras/legacy/layers.py @@ -508,8 +508,7 @@ def __call__(self, inputs, initial_state=None, **kwargs): if initial_state is None: return super(Recurrent, self).__call__(inputs, **kwargs) - if not isinstance(initial_state, (list, tuple)): - initial_state = [initial_state] + initial_state = to_list(initial_state, allow_tuple=True) is_keras_tensor = hasattr(initial_state[0], '_keras_history') for tensor in initial_state: @@ -602,10 +601,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None): output = last_output if self.return_state: - if not isinstance(states, (list, tuple)): - states = [states] - else: - states = list(states) + states = to_list(states, allow_tuple=True) return [output] + states else: return output @@ -633,8 +629,7 @@ def reset_states(self, states=None): for state in self.states: K.set_value(state, np.zeros((batch_size, self.units))) else: - if not isinstance(states, (list, tuple)): - states = [states] + states = to_list(states, allow_tuple=True) if len(states) != len(self.states): raise ValueError('Layer ' + self.name + ' expects ' + str(len(self.states)) + ' states, ' diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py index 864dbbaba1ec..9ea10a2b9724 100644 --- a/keras/utils/generic_utils.py +++ b/keras/utils/generic_utils.py @@ -444,7 +444,7 @@ def add(self, n, values=None): self.update(self._seen_so_far + n, values) -def to_list(x): +def to_list(x, allow_tuple=False): """Normalizes a list/tensor into a list. If a tensor is passed, we return @@ -452,12 +452,18 @@ def to_list(x): # Arguments x: target object to be normalized. + allow_tuple: If False and x is a tuple, + it will be converted into a list + with a single element (the tuple). + Else converts the tuple to a list. # Returns A list. """ if isinstance(x, list): return x + if allow_tuple and isinstance(x, tuple): + return list(x) return [x] @@ -483,10 +489,7 @@ def object_list_uid(object_list): def is_all_none(iterable_or_element): - if not isinstance(iterable_or_element, (list, tuple)): - iterable = [iterable_or_element] - else: - iterable = iterable_or_element + iterable = to_list(iterable_or_element, allow_tuple=True) for element in iterable: if element is not None: return False