Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions marimo/_runtime/reload/autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,7 @@ def append_obj(
name: str,
obj: object,
) -> bool:
in_module = (
hasattr(obj, "__module__") and obj.__module__ == module.__name__
)
in_module = getattr(obj, "__module__", None) == module.__name__
if not in_module:
return False

Expand All @@ -431,37 +429,27 @@ def superreload(
if old_objects is None:
old_objects = {}

module_dict = module.__dict__
mod_name = module.__name__

# collect old objects in the module
for name, obj in list(module.__dict__.items()):
if not append_obj(module, old_objects, name, obj):
continue
key = (module.__name__, name)
try:
old_objects.setdefault(key, []).append(weakref.ref(obj))
except TypeError:
pass
for name, obj in module_dict.items():
append_obj(module, old_objects, name, obj)

# reload module
old_dict: dict[str, Any] | None = None
try:
# clear namespace first from old cruft
old_dict = module.__dict__.copy()
old_name = module.__name__
module.__dict__.clear()
module.__dict__["__name__"] = old_name
module.__dict__["__loader__"] = old_dict["__loader__"]
old_dict = module_dict.copy()
module_dict.clear()
module_dict["__name__"] = mod_name
module_dict["__loader__"] = old_dict["__loader__"]
except (TypeError, AttributeError, KeyError):
pass

try:
module = reload(module)
except Exception as e:
# User introduced a SyntaxError, ModuleNotFoundError, etc -- they
# should be told, and module dict should not be restored, ie don't fail
# silently.
#
# It's possible that the module fails to reload for some other reason.
# In this case, too, the failure shouldn't be silent!
sys.stderr.write(
f"Error trying to reload module {module.__name__}: {str(e)} \n"
)
Expand All @@ -472,13 +460,17 @@ def superreload(
raise

# iterate over all objects and update functions & classes
for name, new_obj in list(module.__dict__.items()):
key = (module.__name__, name)
if key not in old_objects:
module_dict = module.__dict__
mod_name = module.__name__

for name, new_obj in module_dict.items():
key = (mod_name, name)
old_refs = old_objects.get(key)
if not old_refs:
continue

new_refs = []
for old_ref in old_objects[key]:
for old_ref in old_refs:
old_obj = old_ref()
if old_obj is None:
continue
Expand Down