Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 49 additions & 28 deletions marimo/_ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,33 +277,40 @@ def _get_alias_name(
NB: We disallow `import *` because Python only allows
star imports at module-level, but we store cells as functions.
"""
if node.asname is None:
# Imported name without an "as" clause. Examples:
# import [a.b.c] - we define a
# from foo import [a] - we define a
# from foo import [*] - we don't define anything
#
# Note:
# Don't mangle - user has no control over package name
basename = node.name.split(".")[0]
# Fast-path: cache the split for node.name in all cases
asname = node.asname
if asname is None:
# Only split if we need to
name = node.name
if "." in name:
basename = name[: name.index(".")]
else:
basename = name
if basename == "*":
# Use the ImportFrom node's line number for consistency
line_num = (
import_node.lineno
if import_node and hasattr(import_node, "lineno")
else node.lineno
if hasattr(node, "lineno")
# Only try to get line numbers if we actually must raise
import_node_lineno = (
getattr(import_node, "lineno", None)
if import_node
else None
)
line = f"line {line_num}" if line_num else "line ..."
node_lineno = getattr(node, "lineno", None)
line_num = (
import_node_lineno
if import_node_lineno is not None
else node_lineno
)
line = (
f"line {line_num}" if line_num is not None else "line ..."
)
raise ImportStarError(
f"{line} SyntaxError: Importing symbols with `import *` "
"is not allowed in marimo."
)
return basename
else:
node.asname = self._if_local_then_mangle(node.asname)
return node.asname
asname_mangled = self._if_local_then_mangle(asname)
node.asname = asname_mangled
return asname_mangled

def _is_defined(self, identifier: str) -> bool:
"""Check if `identifier` is defined in any block."""
Expand Down Expand Up @@ -391,7 +398,9 @@ def _define(
Names created with the global keyword are added to the top-level
(global scope) block.
"""
block_idx = 0 if name in self.block_stack[-1].global_names else -1
# Avoid attribute lookup in loop: cache global_names
global_names = self.block_stack[-1].global_names
block_idx = 0 if name in global_names else -1
self._define_in_block(name, variable_data, block_idx=block_idx)
if node is not None:
self._on_def(node, name, self.block_stack)
Expand Down Expand Up @@ -1014,21 +1023,33 @@ def visit_Import(self, node: ast.Import) -> ast.Import:

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
module = node.module if node.module is not None else ""
# we don't recurse into the alias nodes, since we define the
# aliases here
for alias_node in node.names:
variable_name = self._get_alias_name(alias_node, import_node=node)
module_dot = module + "."
names = node.names
# Move definitions out of loop for slight speedup
level = node.level
# Avoid attribute lookup in loop by binding function locally
_get_alias_name = self._get_alias_name
_define = self._define
ImportData_ = ImportData
VariableData_ = VariableData
# Precompute list of tuples to amortize Python loop overhead
variable_names = []
for alias_node in names:
variable_name = _get_alias_name(alias_node, import_node=node)
original_name = alias_node.name
self._define(
variable_names.append((variable_name, module_dot + original_name))
# Use single loop for _define calls
for variable_name, imported_symbol in variable_names:
_define(
None,
variable_name,
VariableData(
VariableData_(
kind="import",
import_data=ImportData(
import_data=ImportData_(
module=module,
definition=variable_name,
imported_symbol=module + "." + original_name,
import_level=node.level,
imported_symbol=imported_symbol,
import_level=level,
),
),
)
Expand Down