@@ -196,76 +196,70 @@ def __repr__(self):
196196 return out
197197
198198
199- def lightning_hasattr (model , attribute ):
200- """ Special hasattr for lightning. Checks for attribute in model namespace,
201- the old hparams namespace/dict, and the datamodule. """
199+ def lightning_get_all_attr_holders (model , attribute ):
200+ """ Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
201+ Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
202202 trainer = getattr (model , 'trainer' , None )
203203
204- attr = False
204+ holders = []
205+
205206 # Check if attribute in model
206207 if hasattr (model , attribute ):
207- attr = True
208+ holders .append (model )
209+
208210 # Check if attribute in model.hparams, either namespace or dict
209- elif hasattr (model , 'hparams' ):
210- if isinstance (model .hparams , dict ):
211- attr = attribute in model .hparams
212- else :
213- attr = hasattr (model .hparams , attribute )
211+ if hasattr (model , 'hparams' ):
212+ if attribute in model .hparams :
213+ holders .append (model .hparams )
214+
214215 # Check if the attribute in datamodule (datamodule gets registered in Trainer)
215- if not attr and trainer is not None :
216- attr = hasattr (trainer .datamodule , attribute )
216+ if trainer is not None and trainer . datamodule is not None and hasattr ( trainer . datamodule , attribute ) :
217+ holders . append (trainer .datamodule )
217218
218- return attr
219+ return holders
220+
221+
222+ def lightning_get_first_attr_holder (model , attribute ):
223+ """ Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace,
224+ the old hparams namespace/dict, and the datamodule, returns the last one that has it. """
225+ holders = lightning_get_all_attr_holders (model , attribute )
226+ if len (holders ) == 0 :
227+ return None
228+ # using the last holder to preserve backwards compatibility
229+ return holders [- 1 ]
230+
231+
232+ def lightning_hasattr (model , attribute ):
233+ """ Special hasattr for lightning. Checks for attribute in model namespace,
234+ the old hparams namespace/dict, and the datamodule. """
235+ return lightning_get_first_attr_holder (model , attribute ) is not None
219236
220237
221238def lightning_getattr (model , attribute ):
222239 """ Special getattr for lightning. Checks for attribute in model namespace,
223240 the old hparams namespace/dict, and the datamodule. """
224- trainer = getattr (model , 'trainer' , None )
241+ holder = lightning_get_first_attr_holder (model , attribute )
242+ if holder is None :
243+ raise ValueError (f'{ attribute } is neither stored in the model namespace'
244+ ' nor the `hparams` namespace/dict, nor the datamodule.' )
225245
226- # Check if attribute in model
227- if hasattr (model , attribute ):
228- attr = getattr (model , attribute )
229- # Check if attribute in model.hparams, either namespace or dict
230- elif hasattr (model , 'hparams' ) and isinstance (model .hparams , dict ) and attribute in model .hparams :
231- attr = model .hparams [attribute ]
232- elif hasattr (model , 'hparams' ) and hasattr (model .hparams , attribute ):
233- attr = getattr (model .hparams , attribute )
234- # Check if the attribute in datamodule (datamodule gets registered in Trainer)
235- elif trainer is not None and trainer .datamodule is not None and hasattr (trainer .datamodule , attribute ):
236- attr = getattr (trainer .datamodule , attribute )
237- else :
238- raise ValueError (
239- f'The { attribute } is neither stored in the model namespace nor the `hparams` namespace/dict,'
240- ' nor the datamodule.'
241- )
242- return attr
246+ if isinstance (holder , dict ):
247+ return holder [attribute ]
248+ return getattr (holder , attribute )
243249
244250
245251def lightning_setattr (model , attribute , value ):
246252 """ Special setattr for lightning. Checks for attribute in model namespace
247253 and the old hparams namespace/dict.
248254 Will also set the attribute on datamodule, if it exists.
249255 """
250- if not lightning_hasattr (model , attribute ):
251- raise ValueError (
252- f'The { attribute } is neither stored in the model namespace nor the `hparams` namespace/dict,'
253- ' nor the datamodule.'
254- )
255-
256- trainer = getattr (model , 'trainer' , None )
257-
258- # Check if attribute in model
259- if hasattr (model , attribute ):
260- setattr (model , attribute , value )
261-
262- # Check if attribute in model.hparams, either namespace or dict
263- elif hasattr (model , 'hparams' ):
264- if isinstance (model .hparams , dict ):
265- model .hparams [attribute ] = value
256+ holders = lightning_get_all_attr_holders (model , attribute )
257+ if len (holders ) == 0 :
258+ raise ValueError (f'{ attribute } is neither stored in the model namespace'
259+ ' nor the `hparams` namespace/dict, nor the datamodule.' )
260+
261+ for holder in holders :
262+ if isinstance (holder , dict ):
263+ holder [attribute ] = value
266264 else :
267- setattr (model .hparams , attribute , value )
268-
269- # Check if the attribute in datamodule (datamodule gets registered in Trainer)
270- if trainer is not None and trainer .datamodule is not None and hasattr (trainer .datamodule , attribute ):
271- setattr (trainer .datamodule , attribute , value )
265+ setattr (holder , attribute , value )
0 commit comments