@@ -207,7 +207,8 @@ def variables(self):
207207
208208 Returns:
209209 A sequence of variables for the current module (sorted by attribute
210- name) followed by variables from all submodules recursively (depth first).
210+ name) followed by variables from all submodules recursively (breadth
211+ first).
211212 """
212213 return tuple (self ._flatten (predicate = _IS_VARIABLE ))
213214
@@ -221,7 +222,8 @@ def trainable_variables(self):
221222
222223 Returns:
223224 A sequence of variables for the current module (sorted by attribute
224- name) followed by variables from all submodules recursively (depth first).
225+ name) followed by variables from all submodules recursively (breadth
226+ first).
225227 """
226228 return tuple (self ._flatten (predicate = _IS_TRAINABLE_VARIABLE ))
227229
@@ -249,7 +251,8 @@ def submodules(self):
249251 def _flatten (self ,
250252 recursive = True ,
251253 predicate = None ,
252- attribute_traversal_key = None ):
254+ attribute_traversal_key = None ,
255+ with_path = False ):
253256 """Flattened attribute values in sorted order by attribute name.
254257
255258 Modules are flattened by first walking their attributes in name order.
@@ -267,11 +270,15 @@ def _flatten(self,
267270 ...
268271 ... @property
269272 ... def tensors(self):
270- ... return tuple(self._flatten(predicate=is_tensor))
273+ ... return tuple(self._flatten(predicate=is_tensor, with_path=True ))
271274
272275 >>> foo = Foo()
273276 >>> foo.tensors
274- (<tf.Tensor...'a'>, <tf.Tensor...'b'>, ...'c'>, ...'d'>, ...'e'>)
277+ ((('x', 0), <tf.Tensor: ...'a'>),
278+ (('x', 1), <tf.Tensor: ...'b'>),
279+ (('y', 'i'), <tf.Tensor: ...'c'>),
280+ (('y', 'j'), <tf.Tensor: ...'d'>),
281+ (('z',), <tf.Tensor: ...'e'>))
275282
276283 `attribute_traversal_key` controls the order object properties are visited.
277284 If not set objects are visited in ascending order by name.
@@ -284,6 +291,10 @@ def _flatten(self,
284291 attribute_traversal_key: (Optional) Method to rekey object attributes
285292 before they are sorted. Contract is the same as `key` argument to
286293 builtin `sorted` and only applies to object properties.
294+ with_path: (Optional) Whether to include the path to the object as well
295+ as the object itself. If `with_path` is `True` then leaves will not be
296+ de-duplicated (e.g. if the same leaf instance is reachable via multiple
297+ modules then it will be yielded multiple times with different paths).
287298
288299 Returns:
289300 Flat generator for leaves of the current module and optionally all
@@ -297,7 +308,7 @@ def _flatten(self,
297308 recursive = recursive ,
298309 predicate = predicate ,
299310 attribute_traversal_key = attribute_traversal_key ,
300- seen = set () )
311+ with_path = with_path )
301312
302313 @classmethod
303314 def no_name_scope (cls , method ):
@@ -337,8 +348,20 @@ def camel_to_snake(value):
337348 return _CAMEL_TO_SNAKE_R .sub (r"_\1" , value ).lower ()
338349
339350
340- def _flatten_module (module , recursive , predicate , attribute_traversal_key ,
341- seen ):
351+ # AutoCheckpointable adds object attributes that users will not expect us to
352+ # include when flattening (these reference dependencies reachable via other
353+ # object attributes).
354+ AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies" ,
355+ "_unconditional_dependency_names" )
356+
357+
358+ def _flatten_module (module ,
359+ recursive ,
360+ predicate ,
361+ attribute_traversal_key ,
362+ with_path ,
363+ module_path = (),
364+ seen = None ):
342365 """Implementation of `flatten`."""
343366 if seen is None :
344367 seen = set ([id (module )])
@@ -347,25 +370,37 @@ def _flatten_module(module, recursive, predicate, attribute_traversal_key,
347370 submodules = []
348371
349372 for key in sorted (module_dict , key = attribute_traversal_key ):
350- for leaf in nest .flatten (module_dict [key ]):
351- leaf_id = id (leaf )
352- if leaf_id in seen :
353- continue
373+ if key in AUTO_CHECKPOINTABLE_ATTRS :
374+ continue
375+
376+ for leaf_path , leaf in nest .flatten_with_tuple_paths (module_dict [key ]):
377+ leaf_path = (key ,) + leaf_path
378+
379+ # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
380+ if not with_path :
381+ leaf_id = id (leaf )
382+ if leaf_id in seen :
383+ continue
384+ seen .add (leaf_id )
354385
355- seen .add (leaf_id )
356386 if predicate (leaf ):
357- yield leaf
387+ if with_path :
388+ yield module_path + leaf_path , leaf
389+ else :
390+ yield leaf
358391
359392 if recursive and isinstance (leaf , Module ):
360393 # Walk direct properties first then recurse.
361- submodules .append (leaf )
394+ submodules .append (( module_path + leaf_path , leaf ) )
362395
363- for submodule in submodules :
396+ for submodule_path , submodule in submodules :
364397 subvalues = _flatten_module (
365398 submodule ,
366399 recursive = recursive ,
367400 predicate = predicate ,
368401 attribute_traversal_key = attribute_traversal_key ,
402+ with_path = with_path ,
403+ module_path = submodule_path ,
369404 seen = seen )
370405
371406 for subvalue in subvalues :
0 commit comments