diff --git a/marimo/_runtime/reload/autoreload.py b/marimo/_runtime/reload/autoreload.py index 61a1e4ced31..b4edfa1bb87 100644 --- a/marimo/_runtime/reload/autoreload.py +++ b/marimo/_runtime/reload/autoreload.py @@ -402,9 +402,7 @@ def append_obj( name: str, obj: object, ) -> bool: - in_module = ( - hasattr(obj, "__module__") and obj.__module__ == module.__name__ - ) + in_module = getattr(obj, "__module__", None) == module.__name__ if not in_module: return False @@ -431,37 +429,27 @@ def superreload( if old_objects is None: old_objects = {} + module_dict = module.__dict__ + mod_name = module.__name__ + # collect old objects in the module - for name, obj in list(module.__dict__.items()): - if not append_obj(module, old_objects, name, obj): - continue - key = (module.__name__, name) - try: - old_objects.setdefault(key, []).append(weakref.ref(obj)) - except TypeError: - pass + for name, obj in module_dict.items(): + append_obj(module, old_objects, name, obj) # reload module old_dict: dict[str, Any] | None = None try: # clear namespace first from old cruft - old_dict = module.__dict__.copy() - old_name = module.__name__ - module.__dict__.clear() - module.__dict__["__name__"] = old_name - module.__dict__["__loader__"] = old_dict["__loader__"] + old_dict = module_dict.copy() + module_dict.clear() + module_dict["__name__"] = mod_name + module_dict["__loader__"] = old_dict["__loader__"] except (TypeError, AttributeError, KeyError): pass try: module = reload(module) except Exception as e: - # User introduced a SyntaxError, ModuleNotFoundError, etc -- they - # should be told, and module dict should not be restored, ie don't fail - # silently. - # - # It's possible that the module fails to reload for some other reason. - # In this case, too, the failure shouldn't be silent! sys.stderr.write( f"Error trying to reload module {module.__name__}: {str(e)} \n" ) @@ -472,13 +460,17 @@ def superreload( raise # iterate over all objects and update functions & classes - for name, new_obj in list(module.__dict__.items()): - key = (module.__name__, name) - if key not in old_objects: + module_dict = module.__dict__ + mod_name = module.__name__ + + for name, new_obj in module_dict.items(): + key = (mod_name, name) + old_refs = old_objects.get(key) + if not old_refs: continue new_refs = [] - for old_ref in old_objects[key]: + for old_ref in old_refs: old_obj = old_ref() if old_obj is None: continue