From 9b4ede56a3c6f5c38ca5272433de2adae152b6b8 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Wed, 16 Apr 2025 14:14:05 -0400 Subject: [PATCH 1/6] initial implementation --- .../code_directories/retriever/import_test.py | 5 + codeflash/context/code_context_extractor.py | 10 +- .../context/unused_definition_remover.py | 476 ++++++++++++++++++ tests/test_code_context_extractor.py | 43 +- tests/test_remove_unused_definitions.py | 416 +++++++++++++++ 5 files changed, 923 insertions(+), 27 deletions(-) create mode 100644 code_to_optimize/code_directories/retriever/import_test.py create mode 100644 codeflash/context/unused_definition_remover.py create mode 100644 tests/test_remove_unused_definitions.py diff --git a/code_to_optimize/code_directories/retriever/import_test.py b/code_to_optimize/code_directories/retriever/import_test.py new file mode 100644 index 000000000..7f12f0a89 --- /dev/null +++ b/code_to_optimize/code_directories/retriever/import_test.py @@ -0,0 +1,5 @@ + +import code_to_optimize.code_directories.retriever.main + +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 516f3c94e..c989dc708 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -14,6 +14,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages +from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( CodeContextType, @@ -189,7 +190,7 @@ def extract_code_string_context_from_files( helpers_of_helpers_qualified_names, remove_docstrings, ) - + code_context = remove_unused_definitions_by_function_names(code_context, qualified_function_names | helpers_of_helpers_qualified_names) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue @@ -217,6 +218,7 @@ def extract_code_string_context_from_files( code_context = parse_code_and_prune_cst( original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) + code_context = remove_unused_definitions_by_function_names(code_context, qualified_helper_function_names) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue @@ -290,6 +292,9 @@ def extract_code_markdown_context_from_files( helpers_of_helpers_qualified_names, remove_docstrings, ) + code_context = remove_unused_definitions_by_function_names( + code_context, qualified_function_names | helpers_of_helpers_qualified_names + ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -321,6 +326,9 @@ def extract_code_markdown_context_from_files( code_context = parse_code_and_prune_cst( original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) + code_context = remove_unused_definitions_by_function_names( + code_context, qualified_helper_function_names + ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py new file mode 100644 index 000000000..82a805259 --- 
/dev/null
+++ b/codeflash/context/unused_definition_remover.py
@@ -0,0 +1,476 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+import libcst as cst
+
+
+@dataclass
+class UsageInfo:
+    """Information about a name and its usage."""
+
+    name: str
+    used_by_qualified_function: bool = False
+    dependencies: set[str] = field(default_factory=set)
+
+
+def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
+    """Extract all variable names from a target node, including from tuple unpacking."""
+    names = []
+
+    # Handle a simple name
+    if isinstance(target, cst.Name):
+        names.append(target.value)
+
+    # Handle any node with a value attribute (StarredElement, etc.)
+    elif hasattr(target, "value"):
+        names.extend(extract_names_from_targets(target.value))
+
+    # Handle any node with an elements attribute (tuples, lists, etc.)
+    elif hasattr(target, "elements"):
+        for element in target.elements:
+            # Recursive call for each element
+            names.extend(extract_names_from_targets(element))
+
+    return names
+
+
+def collect_top_level_definitions(node: cst.CSTNode, definitions: dict[str, UsageInfo] | None = None) -> dict[str, UsageInfo]:
+    """Recursively collect all top-level variable, function, and class definitions."""
+    if definitions is None:
+        definitions = {}
+
+    # Handle top-level function definitions
+    if isinstance(node, cst.FunctionDef):
+        name = node.name.value
+        definitions[name] = UsageInfo(
+            name=name,
+            used_by_qualified_function=False,  # Will be marked later if in qualified functions
+        )
+        return definitions
+
+    # Handle top-level class definitions
+    if isinstance(node, cst.ClassDef):
+        name = node.name.value
+        definitions[name] = UsageInfo(name=name)
+
+        # Also collect method definitions within the class
+        if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock):
+            for statement in node.body.body:
+                if isinstance(statement, cst.FunctionDef):
+                    method_name = f"{name}.{statement.name.value}"
+                    definitions[method_name] = UsageInfo(name=method_name)
+
+        return definitions
+
+    # Handle top-level variable assignments
+    if isinstance(node, cst.Assign):
+        for target in node.targets:
+            names = extract_names_from_targets(target.target)
+            for name in names:
+                definitions[name] = UsageInfo(name=name)
+        return definitions
+
+    if isinstance(node, cst.AnnAssign | cst.AugAssign):
+        if isinstance(node.target, cst.Name):
+            name = node.target.value
+            definitions[name] = UsageInfo(name=name)
+        else:
+            names = extract_names_from_targets(node.target)
+            for name in names:
+                definitions[name] = UsageInfo(name=name)
+        return definitions
+
+    # Recursively process children. 
Takes care of top level assignments in if/else/while/for blocks + section_names = get_section_names(node) + + if section_names: + for section in section_names: + original_content = getattr(node, section, None) + # If section contains a list of nodes + if isinstance(original_content, list | tuple): + for child in original_content: + collect_top_level_definitions(child, definitions) + # If section contains a single node + elif original_content is not None: + collect_top_level_definitions(original_content, definitions) + + return definitions + + +def get_section_names(node: cst.CSTNode) -> list[str]: + """Return the section attribute names (e.g., body, orelse) for a given node if they exist.""" + possible_sections = ["body", "orelse", "finalbody", "handlers"] + return [sec for sec in possible_sections if hasattr(node, sec)] + + +class DependencyCollector(cst.CSTVisitor): + """Collects dependencies between definitions using the visitor pattern with depth tracking.""" + + def __init__(self, definitions: dict[str, UsageInfo]) -> None: + super().__init__() + self.definitions = definitions + # Track function and class depths + self.function_depth = 0 + self.class_depth = 0 + # Track top-level qualified names + self.current_top_level_name = "" + self.current_class = "" + # Track if we're processing a top-level variable + self.processing_variable = False + self.current_variable_names = set() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + function_name = node.name.value + + if self.function_depth == 0: + # This is a top-level function + if self.class_depth > 0: + # If inside a class, we're now tracking dependencies at the class level + self.current_top_level_name = f"{self.current_class}.{function_name}" + else: + # Regular top-level function + self.current_top_level_name = function_name + + # Check parameter type annotations for dependencies + if hasattr(node, "params") and node.params: + for param in node.params.params: + if param.annotation: + # Visit the annotation to extract dependencies + self._collect_annotation_dependencies(param.annotation) + + self.function_depth += 1 + + def _collect_annotation_dependencies(self, annotation: cst.Annotation) -> None: + """Extract dependencies from type annotations""" + if hasattr(annotation, "annotation"): + # Extract names from annotation (could be Name, Attribute, Subscript, etc.) + self._extract_names_from_annotation(annotation.annotation) + + def _extract_names_from_annotation(self, node: cst.CSTNode) -> None: + """Extract names from a type annotation node""" + # Simple name reference like 'int', 'str', or custom type + if isinstance(node, cst.Name): + name = node.value + if name in self.definitions and name != self.current_top_level_name and self.current_top_level_name: + self.definitions[self.current_top_level_name].dependencies.add(name) + + # Handle compound annotations like List[int], Dict[str, CustomType], etc. 
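+        # e.g. for "Dict[str, CustomType]" we recurse into the subscript's base
+        # ("Dict") and then into each slice element, so nested custom types are
+        # inspected as well.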
+ elif isinstance(node, cst.Subscript): + if hasattr(node, "value"): + self._extract_names_from_annotation(node.value) + if hasattr(node, "slice"): + for slice_item in node.slice: + if hasattr(slice_item, "slice"): + self._extract_names_from_annotation(slice_item.slice) + + # Handle attribute access like module.Type + elif isinstance(node, cst.Attribute): + if hasattr(node, "value"): + self._extract_names_from_annotation(node.value) + # No need to check the attribute name itself as it's likely not a top-level definition + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.function_depth -= 1 + + if self.function_depth == 0 and self.class_depth == 0: + # Exiting top-level function that's not in a class + self.current_top_level_name = "" + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + class_name = node.name.value + + if self.class_depth == 0: + # This is a top-level class + self.current_class = class_name + self.current_top_level_name = class_name + + self.class_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.class_depth -= 1 + + if self.class_depth == 0: + # Exiting top-level class + self.current_class = "" + self.current_top_level_name = "" + + def visit_Assign(self, node: cst.Assign) -> None: + # Only handle top-level assignments + if self.function_depth == 0 and self.class_depth == 0: + for target in node.targets: + # Extract all variable names from the target + names = extract_names_from_targets(target.target) + + # Check if any of these names are top-level definitions we're tracking + tracked_names = [name for name in names if name in self.definitions] + if tracked_names: + self.processing_variable = True + self.current_variable_names.update(tracked_names) + # Use the first tracked name as the current top-level name (for dependency tracking) + self.current_top_level_name = tracked_names[0] + + def leave_Assign(self, original_node: cst.Assign) -> None: + if self.processing_variable: + self.processing_variable = False + self.current_variable_names.clear() + self.current_top_level_name = "" + + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: + # Extract names from the variable annotations + if hasattr(node, "annotation") and node.annotation: + # First mark we're processing a variable to avoid recording it as a dependency of itself + self.processing_variable = True + if isinstance(node.target, cst.Name): + self.current_variable_names.add(node.target.value) + else: + self.current_variable_names.update(extract_names_from_targets(node.target)) + + # Process the annotation + self._collect_annotation_dependencies(node.annotation) + + # Reset processing state + self.processing_variable = False + self.current_variable_names.clear() + + def visit_Name(self, node: cst.Name) -> None: + name = node.value + + # Skip if we're not inside a tracked definition + if not self.current_top_level_name or self.current_top_level_name not in self.definitions: + return + + # Skip if we're looking at the variable name itself in an assignment + if self.processing_variable and name in self.current_variable_names: + return + + # Check if name is a top-level definition we're tracking + if name in self.definitions and name != self.current_top_level_name: + self.definitions[self.current_top_level_name].dependencies.add(name) + + +class QualifiedFunctionUsageMarker: + """Marks definitions that are used by specific qualified functions.""" + + def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None: 
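+        """Store the tracked definitions and precompute the expanded set of qualified names."""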
+ self.definitions = definitions + self.qualified_function_names = qualified_function_names + self.expanded_qualified_functions = self._expand_qualified_functions() + + def _expand_qualified_functions(self) -> set[str]: + """Expand the qualified function names to include related methods.""" + expanded = set(self.qualified_function_names) + + # Find class methods and add their containing classes and dunder methods + for qualified_name in list(self.qualified_function_names): + if "." in qualified_name: + class_name, method_name = qualified_name.split(".", 1) + + # Add the class itself + expanded.add(class_name) + + # Add all dunder methods of the class + for name in self.definitions: + if name.startswith(f"{class_name}.__") and name.endswith("__"): + expanded.add(name) + + return expanded + + def mark_used_definitions(self) -> None: + """Find all qualified functions and mark them and their dependencies as used.""" + # First identify all specified functions (including expanded ones) + functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions] + + # For each specified function, mark it and all its dependencies as used + for func_name in functions_to_mark: + self.definitions[func_name].used_by_qualified_function = True + for dep in self.definitions[func_name].dependencies: + self.mark_as_used_recursively(dep) + + def mark_as_used_recursively(self, name: str) -> None: + """Mark a name and all its dependencies as used recursively.""" + if name not in self.definitions: + return + + if self.definitions[name].used_by_qualified_function: + return # Already marked + + self.definitions[name].used_by_qualified_function = True + + # Mark all dependencies as used + for dep in self.definitions[name].dependencies: + self.mark_as_used_recursively(dep) + + +def remove_unused_definitions_recursively( + node: cst.CSTNode, definitions: dict[str, UsageInfo] +) -> tuple[cst.CSTNode | None, bool]: + """Recursively filter the node to remove unused definitions. 
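+
+    Imports and function definitions are always preserved. Within a class body,
+    methods are always kept and class variables are dropped only when neither
+    they nor the class itself are marked as used; other assignments survive only
+    when one of their targets is marked as used.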
+ + Args: + node: The CST node to process + definitions: Dictionary of definition info + + Returns: + (filtered_node, used_by_function): + filtered_node: The modified CST node or None if it should be removed + used_by_function: True if this node or any child is used by qualified functions + + """ + # Skip import statements + if isinstance(node, cst.Import | cst.ImportFrom): + return node, True + + # Never remove function definitions + if isinstance(node, cst.FunctionDef): + return node, True + + # Never remove class definitions + if isinstance(node, cst.ClassDef): + class_name = node.name.value + + # Check if any methods or variables in this class are used + method_or_var_used = False + class_has_dependencies = False + + # Check if class itself is marked as used + if class_name in definitions and definitions[class_name].used_by_qualified_function: + class_has_dependencies = True + + if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): + updates = {} + new_statements = [] + + for statement in node.body.body: + # Keep all function definitions + if isinstance(statement, cst.FunctionDef): + method_name = f"{class_name}.{statement.name.value}" + if method_name in definitions and definitions[method_name].used_by_qualified_function: + method_or_var_used = True + new_statements.append(statement) + # Only process variable assignments + elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + var_used = False + + # Check if any variable in this assignment is used + if isinstance(statement, cst.Assign): + for target in statement.targets: + names = extract_names_from_targets(target.target) + for name in names: + class_var_name = f"{class_name}.{name}" + if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + var_used = True + method_or_var_used = True + break + elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(statement.target) + for name in names: + class_var_name = f"{class_name}.{name}" + if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + var_used = True + method_or_var_used = True + break + + if var_used or class_has_dependencies: + new_statements.append(statement) + else: + # Keep all other statements in the class + new_statements.append(statement) + + # Update the class body + new_body = node.body.with_changes(body=new_statements) + updates["body"] = new_body + + return node.with_changes(**updates), True + + return node, method_or_var_used or class_has_dependencies + + # Handle assignments (Assign and AnnAssign) + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + if name in definitions and definitions[name].used_by_qualified_function: + return node, True + return None, False + + if isinstance(node, cst.AnnAssign | cst.AugAssign): + names = extract_names_from_targets(node.target) + for name in names: + if name in definitions and definitions[name].used_by_qualified_function: + return node, True + return None, False + + # For other nodes, recursively process children + section_names = get_section_names(node) + if not section_names: + return node, False + + updates = {} + found_used = False + + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, list | tuple): + new_children = [] + section_found_used = False + + for child in original_content: + filtered, used = 
remove_unused_definitions_recursively(child, definitions) + if filtered: + new_children.append(filtered) + section_found_used |= used + + if new_children or section_found_used: + found_used |= section_found_used + updates[section] = new_children + elif original_content is not None: + filtered, used = remove_unused_definitions_recursively(original_content, definitions) + found_used |= used + if filtered: + updates[section] = filtered + if not found_used: + return None, False + if updates: + return node.with_changes(**updates), found_used + + return node, False + + +def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str: + """Analyze a file and remove top level definitions not used by specified functions. + + Top level definitions, in this context, are only classes, variables or functions. + If a class is referenced by a qualified function, we keep the entire class. + + Args: + code: The code to process + qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname' + + """ + module = cst.parse_module(code) + # Collect all definitions (top level classes, variables or function) + definitions = collect_top_level_definitions(module) + + # Collect dependencies between definitions using the visitor pattern + dependency_collector = DependencyCollector(definitions) + module.visit(dependency_collector) + + # Mark definitions used by specified functions, and their dependencies recursively + usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names) + usage_marker.mark_used_definitions() + + # Apply the recursive removal transformation + modified_module, _ = remove_unused_definitions_recursively(module, definitions) + + return modified_module.code if modified_module else "" + + +def print_definitions(definitions: dict[str, UsageInfo]) -> None: + """Print information about each definition without the complex node object, used for debugging.""" + print(f"Found {len(definitions)} definitions:") + for name, info in sorted(definitions.items()): + print(f" - Name: {name}") + print(f" Used by qualified function: {info.used_by_qualified_function}") + print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}") + print() diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 069a8eb19..da3f8350a 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -929,9 +929,6 @@ def fetch_and_process_data(): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -941,11 +938,6 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` -```python:{path_to_file.relative_to(project_root)} -if __name__ == "__main__": - result = fetch_and_process_data() - print("Processed data:", result) -``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() @@ -1006,9 +998,6 @@ def fetch_and_transform_data(): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1018,11 +1007,6 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return 
f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` -```python:{path_to_file.relative_to(project_root)} -if __name__ == "__main__": - result = fetch_and_process_data() - print("Processed data:", result) -``` ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: @@ -1084,9 +1068,6 @@ def transform(self, data): return self.data ``` ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1147,9 +1128,6 @@ def update_data(data): return data + " updated" ``` ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1252,9 +1230,6 @@ def circular_dependency(self, data): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1322,4 +1297,20 @@ def outside_method(): ``` """ assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() \ No newline at end of file + assert read_only_context.strip() == expected_read_only_context.strip() + +def test_direct_module_import() -> None: + project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" + path_to_main = project_root / "main.py" + path_to_fto = project_root / "import_test.py" + function_to_optimize = FunctionToOptimize( + function_name="function_to_optimize", + file_path=str(path_to_fto), + parents=[], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + print(read_only_context.strip()) \ No newline at end of file diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py new file mode 100644 index 000000000..b2bd19f56 --- /dev/null +++ b/tests/test_remove_unused_definitions.py @@ -0,0 +1,416 @@ +import libcst as cst + +from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names + + +def test_variable_removal_only() -> None: + """Test that only variables not used by specified functions are removed, not functions.""" + code = """ +def main_function(): + return USED_CONSTANT + 10 + +def helper_function(): + return 42 + +USED_CONSTANT = 42 +UNUSED_CONSTANT = 123 + +def another_function(): + return UNUSED_CONSTANT +""" + + expected = """ +def main_function(): + return USED_CONSTANT + 10 + +def helper_function(): + return 42 + +USED_CONSTANT = 42 + +def another_function(): + return UNUSED_CONSTANT +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_class_variable_removal() -> None: + """Test that only class variables not used by specified functions are removed, not methods.""" + code = """ +class MyClass: + CLASS_USED = "used value" + CLASS_UNUSED = "unused value" + + def __init__(self): + self.value = self.CLASS_USED + self.other = self.CLASS_UNUSED + + def used_method(self): + return self.value + + def unused_method(self): + return "Not used but not removed" + +GLOBAL_USED = "global used" +GLOBAL_UNUSED = "global unused" + +def helper_function(): + return 
MyClass().used_method() + GLOBAL_USED +""" + + expected = """ +class MyClass: + CLASS_USED = "used value" + CLASS_UNUSED = "unused value" + + def __init__(self): + self.value = self.CLASS_USED + self.other = self.CLASS_UNUSED + + def used_method(self): + return self.value + + def unused_method(self): + return "Not used but not removed" + +GLOBAL_USED = "global used" + +def helper_function(): + return MyClass().used_method() + GLOBAL_USED +""" + + qualified_functions = {"helper_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_complex_variable_dependencies() -> None: + """Test that only variables with complex dependencies are properly handled.""" + code = """ +def main_function(): + return DIRECT_DEPENDENCY + +def unused_function(): + return "Not used but not removed" + +DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix" +INDIRECT_DEPENDENCY = "base value" +UNUSED_VARIABLE = "This should be removed" + +TUPLE_USED, TUPLE_UNUSED = ("used", "unused") + +def tuple_user(): + return TUPLE_USED +""" + + expected = """ +def main_function(): + return DIRECT_DEPENDENCY + +def unused_function(): + return "Not used but not removed" + +DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix" +INDIRECT_DEPENDENCY = "base value" + +def tuple_user(): + return TUPLE_USED +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + + +def test_type_annotation_usage() -> None: + """Test that variables used in type annotations are considered used.""" + code = """ +# Type definition +CustomType = int +UnusedType = str + +def main_function(param: CustomType) -> CustomType: + return param + 10 + +def unused_function(param: UnusedType) -> UnusedType: + return param + " suffix" + +UNUSED_CONSTANT = 123 +""" + + expected = """ +# Type definition +CustomType = int + +def main_function(param: CustomType) -> CustomType: + return param + 10 + +def unused_function(param: UnusedType) -> UnusedType: + return param + " suffix" + +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_class_method_with_dunder_methods() -> None: + """Test that when a class method is used, dunder methods of that class are preserved.""" + code = """ +class MyClass: + CLASS_VAR = "class variable" + UNUSED_VAR = GLOBAL_VAR_2 + + def __init__(self, value): + self.value = GLOBAL_VAR + + def __str__(self): + return f"MyClass({self.value})" + + def target_method(self): + return self.value * 2 + + def unused_method(self): + return "Not used" + +GLOBAL_VAR = "global" +GLOBAL_VAR_2 = "global" +UNUSED_GLOBAL = "unused global" + +def helper_function(): + obj = MyClass(5) + return obj.target_method() +""" + + expected = """ +class MyClass: + CLASS_VAR = "class variable" + UNUSED_VAR = GLOBAL_VAR_2 + + def __init__(self, value): + self.value = GLOBAL_VAR + + def __str__(self): + return f"MyClass({self.value})" + + def target_method(self): + return self.value * 2 + + def unused_method(self): + return "Not used" + +GLOBAL_VAR = "global" +GLOBAL_VAR_2 = "global" + +def helper_function(): + obj = MyClass(5) + return obj.target_method() +""" + + qualified_functions = {"MyClass.target_method"} + result = 
remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_complex_type_annotations() -> None: + """Test complex type annotations with nested types.""" + code = """ +from typing import List, Dict, Optional + +# Type aliases +ItemType = Dict[str, int] +ResultType = List[ItemType] +UnusedType = Optional[str] + +def process_data(items: ResultType) -> int: + total = 0 + for item in items: + for key, value in item.items(): + total += value + return total + +def unused_function(param: UnusedType) -> None: + pass + +# Variables +SAMPLE_DATA: ResultType = [{"a": 1, "b": 2}] +UNUSED_DATA: UnusedType = None +""" + + expected = """ +from typing import List, Dict, Optional + +# Type aliases +ItemType = Dict[str, int] +ResultType = List[ItemType] + +def process_data(items: ResultType) -> int: + total = 0 + for item in items: + for key, value in item.items(): + total += value + return total + +def unused_function(param: UnusedType) -> None: + pass +""" + + qualified_functions = {"process_data"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + + +def test_try_except_finally_variables() -> None: + """Test handling of variables defined in try-except-finally blocks.""" + code = """ +import math +import os + +# Top-level try-except that defines variables +try: + MATH_CONSTANT = math.pi + USED_ERROR_MSG = "An error occurred" + UNUSED_CONST = 42 +except ImportError: + MATH_CONSTANT = 3.14 + USED_ERROR_MSG = "Math module not available" + UNUSED_CONST = 0 +finally: + CLEANUP_FLAG = True + UNUSED_CLEANUP = "Not used" + +def use_constants(): + return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}" + +def use_cleanup(): + if CLEANUP_FLAG: + return "Cleanup performed" + return "No cleanup" + +def unused_function(): + return UNUSED_CONST +""" + + expected = """ +import math +import os + +# Top-level try-except that defines variables +try: + MATH_CONSTANT = math.pi + USED_ERROR_MSG = "An error occurred" +except ImportError: + MATH_CONSTANT = 3.14 + USED_ERROR_MSG = "Math module not available" +finally: + CLEANUP_FLAG = True + +def use_constants(): + return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}" + +def use_cleanup(): + if CLEANUP_FLAG: + return "Cleanup performed" + return "No cleanup" + +def unused_function(): + return UNUSED_CONST +""" + + qualified_functions = {"use_constants", "use_cleanup"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + +def test_conditional_and_loop_variables() -> None: + """Test handling of variables defined in if-else and while loops.""" + code = """ +import sys +import platform + +# Top-level if-else block defining variables +if sys.platform.startswith('win'): + OS_TYPE = "Windows" + OS_SEP = "" + UNUSED_WIN_VAR = "Unused Windows variable" +elif sys.platform.startswith('linux'): + OS_TYPE = "Linux" + OS_SEP = "/" + UNUSED_LINUX_VAR = "Unused Linux variable" +else: + OS_TYPE = "Other" + OS_SEP = "/" + UNUSED_OTHER_VAR = "Unused other variable" + +# While loop with variable definitions +counter = 0 +while counter < 5: + LOOP_RESULT = "Iteration " + str(counter) + UNUSED_LOOP_VAR = "Unused loop " + str(counter) + counter += 1 + +def get_platform_info(): + return "OS: " + OS_TYPE + ", Separator: " + OS_SEP + +def get_loop_result(): + return LOOP_RESULT + +def unused_function(): + result = 
"" + if sys.platform.startswith('win'): + result = UNUSED_WIN_VAR + elif sys.platform.startswith('linux'): + result = UNUSED_LINUX_VAR + else: + result = UNUSED_OTHER_VAR + return result +""" + + expected = """ +import sys +import platform + +# Top-level if-else block defining variables +if sys.platform.startswith('win'): + OS_TYPE = "Windows" + OS_SEP = "" +elif sys.platform.startswith('linux'): + OS_TYPE = "Linux" + OS_SEP = "/" +else: + OS_TYPE = "Other" + OS_SEP = "/" + +# While loop with variable definitions +counter = 0 +while counter < 5: + LOOP_RESULT = "Iteration " + str(counter) + counter += 1 + +def get_platform_info(): + return "OS: " + OS_TYPE + ", Separator: " + OS_SEP + +def get_loop_result(): + return LOOP_RESULT + +def unused_function(): + result = "" + if sys.platform.startswith('win'): + result = UNUSED_WIN_VAR + elif sys.platform.startswith('linux'): + result = UNUSED_LINUX_VAR + else: + result = UNUSED_OTHER_VAR + return result +""" + + qualified_functions = {"get_platform_info", "get_loop_result"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() \ No newline at end of file From 7cda6aafa7833b804601346bacb617b3b8663ed2 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 17 Apr 2025 18:44:15 -0400 Subject: [PATCH 2/6] tests for code context extractor --- tests/test_code_context_extractor.py | 405 +++++++++++++++++++++++- tests/test_remove_unused_definitions.py | 10 +- 2 files changed, 413 insertions(+), 2 deletions(-) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index da3f8350a..6bf466af5 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1311,6 +1311,409 @@ def test_direct_module_import() -> None: ending_line=None, ) + code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - print(read_only_context.strip()) \ No newline at end of file + + expected_read_only_context = """ +```python:utils.py +from transform_utils import DataTransformer + +class DataProcessor: + \"\"\"A class for processing data.\"\"\" + + number = 1 + + def __repr__(self) -> str: + \"\"\"Return a string representation of the DataProcessor.\"\"\" + return f"DataProcessor(default_prefix={self.default_prefix!r})" + + def process_data(self, raw_data: str) -> str: + \"\"\"Process raw data by converting it to uppercase.\"\"\" + return raw_data.upper() + + def transform_data(self, data: str) -> str: + \"\"\"Transform the processed data\"\"\" + return DataTransformer().transform(data) +```""" + expected_read_write_context = """ +import requests +from globals import API_URL +from utils import DataProcessor +import code_to_optimize.code_directories.retriever.main + +def fetch_and_transform_data(): + # Use the global variable for the request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed + + + +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +""" + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + +def test_comfy_module_import() -> None: + code = ''' 
+import model_management + +class HunyuanVideoClipModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) + self.dtypes = set([dtype, dtype_llama]) + + def set_clip_options(self, options): + self.clip_l.set_clip_options(options) + self.llama.set_clip_options(options) + + def reset_clip_options(self): + self.clip_l.reset_clip_options() + self.llama.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_llama = token_weight_pairs["llama"] + + llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) + + template_end = 0 + extra_template_end = 0 + extra_sizes = 0 + user_end = 9999999999999 + images = [] + + tok_pairs = token_weight_pairs_llama[0] + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 128006: + if tok_pairs[i + 1][0] == 882: + if tok_pairs[i + 2][0] == 128007: + template_end = i + 2 + user_end = -1 + if elem == 128009 and user_end == -1: + user_end = i + 1 + else: + if elem.get("original_type") == "image": + elem_size = elem.get("data").shape[0] + if template_end > 0: + if user_end == -1: + extra_template_end += elem_size - 1 + else: + image_start = i + extra_sizes + image_end = i + elem_size + extra_sizes + images.append((image_start, image_end, elem.get("image_interleave", 1))) + extra_sizes += elem_size - 1 + + if llama_out.shape[1] > (template_end + 2): + if tok_pairs[template_end + 1][0] == 271: + template_end += 2 + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): + llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements + + if len(images) > 0: + out = [] + for i in images: + out.append(llama_out[:, i[0]: i[1]: i[2]]) + llama_output = torch.cat(out + [llama_output], dim=1) + + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + return llama_output, l_pooled, llama_extra_out + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return self.llama.load_sd(sd) +''' + model_management_code = ''' +import psutil +import logging +from enum import Enum +from comfy.cli_args import args, PerformanceFeature +import torch +import sys +import platform +import weakref +import gc + +class VRAMState(Enum): + DISABLED = 0 #No vram present: no need to move models to vram + NO_VRAM = 1 #Very low vram: enable all the options to save vram + LOW_VRAM = 2 + NORMAL_VRAM = 3 + HIGH_VRAM = 4 + SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both. 
+ +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 + +# Determine VRAM State +vram_state = VRAMState.NORMAL_VRAM +set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU + +total_vram = 0 + +def get_supported_float8_types(): + float8_types = [] + try: + float8_types.append(torch.float8_e4m3fn) + except: + pass + try: + float8_types.append(torch.float8_e4m3fnuz) + except: + pass + try: + float8_types.append(torch.float8_e5m2) + except: + pass + try: + float8_types.append(torch.float8_e5m2fnuz) + except: + pass + try: + float8_types.append(torch.float8_e8m0fnu) + except: + pass + return float8_types + +FLOAT8_TYPES = get_supported_float8_types() + +xpu_available = False +torch_version = "" +try: + torch_version = torch.version.__version__ + temp = torch_version.split(".") + torch_version_numeric = (int(temp[0]), int(temp[1])) + xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available() +except: + pass + +lowvram_available = True +if args.deterministic: + logging.info("Using deterministic algorithms for pytorch") + torch.use_deterministic_algorithms(True, warn_only=True) + +directml_enabled = False +if args.directml is not None: + import torch_directml + directml_enabled = True + device_index = args.directml + if device_index < 0: + directml_device = torch_directml.device() + else: + directml_device = torch_directml.device(device_index) + logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index))) + # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. + +try: + import intel_extension_for_pytorch as ipex + _ = torch.xpu.device_count() + xpu_available = xpu_available or torch.xpu.is_available() +except: + xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) + +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS + import torch.mps +except: + pass + +try: + import torch_npu # noqa: F401 + _ = torch.npu.device_count() + npu_available = torch.npu.is_available() +except: + npu_available = False + +try: + import torch_mlu # noqa: F401 + _ = torch.mlu.device_count() + mlu_available = torch.mlu.is_available() +except: + mlu_available = False + +if args.cpu: + cpu_state = CPUState.CPU + +def supports_cast(device, dtype): #TODO + if dtype == torch.float32: + return True + if dtype == torch.float16: + return True + if directml_enabled: #TODO: test this + return False + if dtype == torch.bfloat16: + return True + if is_device_mps(device): + return False + if dtype == torch.float8_e4m3fn: + return True + if dtype == torch.float8_e5m2: + return True + return False + +def pick_weight_dtype(dtype, fallback_dtype, device=None): + if dtype is None: + dtype = fallback_dtype + elif dtype_size(dtype) > dtype_size(fallback_dtype): + dtype = fallback_dtype + + if not supports_cast(device, dtype): + dtype = fallback_dtype + + return dtype + + +''' + + # Create a temporary directory instead of a single file + with tempfile.TemporaryDirectory() as temp_dir: + # Create a package structure + package_dir = Path(temp_dir) / "package" + package_dir.mkdir() + + # Create the __init__.py file to make it a proper package + with open(package_dir / "__init__.py", "w") as init_file: + init_file.write("") + + # Write the model_management.py file + with open(package_dir / "model_management.py", "w") as model_file: + 
model_file.write(model_management_code) + model_file.flush() + + # Write the main code file that imports from model_management + main_file_path = package_dir / "main_module.py" + with open(main_file_path, "w") as main_file: + main_file.write(code) + main_file.flush() + + # Now set up the optimizer with the path to the main file + file_path = main_file_path.resolve() + opt = Optimizer( + Namespace( + project_root=package_dir.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + + function_to_optimize = FunctionToOptimize( + function_name="encode_token_weights", + file_path=file_path, + parents=[FunctionParent(name="HunyuanVideoClipModel", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + expected_read_write_context = """ +import model_management + +class HunyuanVideoClipModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) + self.dtypes = set([dtype, dtype_llama]) + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_llama = token_weight_pairs["llama"] + + llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) + + template_end = 0 + extra_template_end = 0 + extra_sizes = 0 + user_end = 9999999999999 + images = [] + + tok_pairs = token_weight_pairs_llama[0] + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 128006: + if tok_pairs[i + 1][0] == 882: + if tok_pairs[i + 2][0] == 128007: + template_end = i + 2 + user_end = -1 + if elem == 128009 and user_end == -1: + user_end = i + 1 + else: + if elem.get("original_type") == "image": + elem_size = elem.get("data").shape[0] + if template_end > 0: + if user_end == -1: + extra_template_end += elem_size - 1 + else: + image_start = i + extra_sizes + image_end = i + elem_size + extra_sizes + images.append((image_start, image_end, elem.get("image_interleave", 1))) + extra_sizes += elem_size - 1 + + if llama_out.shape[1] > (template_end + 2): + if tok_pairs[template_end + 1][0] == 271: + template_end += 2 + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): + llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements + + if len(images) > 0: + out = [] + for i in images: + out.append(llama_out[:, i[0]: i[1]: i[2]]) + llama_output = torch.cat(out + [llama_output], dim=1) + + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + return llama_output, l_pooled, llama_extra_out +""" + expected_read_only_context = """ 
+```python:model_management.py +# Determine VRAM State + + +def pick_weight_dtype(dtype, fallback_dtype, device=None): + if dtype is None: + dtype = fallback_dtype + elif dtype_size(dtype) > dtype_size(fallback_dtype): + dtype = fallback_dtype + + if not supports_cast(device, dtype): + dtype = fallback_dtype + + return dtype +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() \ No newline at end of file diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index b2bd19f56..86a57bb6d 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -1,6 +1,14 @@ +import tempfile +from argparse import Namespace +from pathlib import Path + import libcst as cst +from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import FunctionParent +from codeflash.optimization.optimizer import Optimizer def test_variable_removal_only() -> None: @@ -413,4 +421,4 @@ def unused_function(): qualified_functions = {"get_platform_info", "get_loop_result"} result = remove_unused_definitions_by_function_names(code, qualified_functions) - assert result.strip() == expected.strip() \ No newline at end of file + assert result.strip() == expected.strip() From ed55d88d2be08a5018fde49d664c10af7c79f758 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 17 Apr 2025 18:58:28 -0400 Subject: [PATCH 3/6] updated other tests --- tests/test_code_replacement.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index ea221be78..d3c4d941a 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -748,10 +748,6 @@ def main_method(self): def test_code_replacement10() -> None: get_code_output = """from __future__ import annotations -import os - -os.environ["CODEFLASH_API_KEY"] = "cf-test-key" - class HelperClass: def __init__(self, name): From 27a0e0e247e4449fd5eafc9328e7b384dcdafc8f Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 17 Apr 2025 19:00:39 -0400 Subject: [PATCH 4/6] use tuple syntax in isinstance --- codeflash/context/unused_definition_remover.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 82a805259..5f3b8b934 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -71,7 +71,7 @@ def collect_top_level_definitions(node: cst.CSTNode, definitions: dict[str, Usag definitions[name] = UsageInfo(name=name) return definitions - if isinstance(node, cst.AnnAssign | cst.AugAssign): + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): if isinstance(node.target, cst.Name): name = node.target.value definitions[name] = UsageInfo(name=name) @@ -318,7 +318,7 @@ def remove_unused_definitions_recursively( """ # Skip import statements - if isinstance(node, cst.Import | cst.ImportFrom): + if isinstance(node, (cst.Import, cst.ImportFrom)): return node, True # Never remove function definitions @@ -394,7 +394,7 @@ def remove_unused_definitions_recursively( return node, True return None, False - if isinstance(node, cst.AnnAssign | 
cst.AugAssign): + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): names = extract_names_from_targets(node.target) for name in names: if name in definitions and definitions[name].used_by_qualified_function: From addb47dfc8de80a4a0ef6bc1885838d2472e0314 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Fri, 18 Apr 2025 15:36:48 -0400 Subject: [PATCH 5/6] syntax fix for isinstance --- codeflash/context/unused_definition_remover.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 5f3b8b934..bfcbbaead 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -88,7 +88,7 @@ def collect_top_level_definitions(node: cst.CSTNode, definitions: dict[str, Usag for section in section_names: original_content = getattr(node, section, None) # If section contains a list of nodes - if isinstance(original_content, list | tuple): + if isinstance(original_content, (list, tuple)): for child in original_content: collect_top_level_definitions(child, definitions) # If section contains a single node @@ -411,7 +411,7 @@ def remove_unused_definitions_recursively( for section in section_names: original_content = getattr(node, section, None) - if isinstance(original_content, list | tuple): + if isinstance(original_content, (list, tuple)): new_children = [] section_found_used = False From 0d77ea4a0b7a7d4bfea86a83aaa57484be792fa1 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Fri, 18 Apr 2025 20:29:38 -0400 Subject: [PATCH 6/6] reworked tests --- codeflash/context/code_context_extractor.py | 24 +- tests/test_code_context_extractor.py | 665 +++++++++++--------- 2 files changed, 382 insertions(+), 307 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index c989dc708..792a76885 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -183,14 +183,16 @@ def extract_code_string_context_from_files( helpers_of_helpers_qualified_names = { func.qualified_name for func in helpers_of_helpers.get(file_path, set()) } + code_without_unused_defs = remove_unused_definitions_by_function_names( + original_code, qualified_function_names | helpers_of_helpers_qualified_names + ) code_context = parse_code_and_prune_cst( - original_code, + code_without_unused_defs, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings, ) - code_context = remove_unused_definitions_by_function_names(code_context, qualified_function_names | helpers_of_helpers_qualified_names) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue @@ -215,10 +217,10 @@ def extract_code_string_context_from_files( continue try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names) code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) - code_context = remove_unused_definitions_by_function_names(code_context, qualified_helper_function_names) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue @@ -285,16 +287,16 @@ def 
extract_code_markdown_context_from_files( helpers_of_helpers_qualified_names = { func.qualified_name for func in helpers_of_helpers.get(file_path, set()) } + code_without_unused_defs = remove_unused_definitions_by_function_names( + original_code, qualified_function_names | helpers_of_helpers_qualified_names + ) code_context = parse_code_and_prune_cst( - original_code, + code_without_unused_defs, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings, ) - code_context = remove_unused_definitions_by_function_names( - code_context, qualified_function_names | helpers_of_helpers_qualified_names - ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -323,11 +325,9 @@ def extract_code_markdown_context_from_files( continue try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names) code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings - ) - code_context = remove_unused_definitions_by_function_names( - code_context, qualified_helper_function_names + code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 6bf466af5..90356ac10 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1363,254 +1363,354 @@ def function_to_optimize(): assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() -def test_comfy_module_import() -> None: - code = ''' -import model_management - -class HunyuanVideoClipModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): - super().__init__() - dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device) - self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) - self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) - self.dtypes = set([dtype, dtype_llama]) - - def set_clip_options(self, options): - self.clip_l.set_clip_options(options) - self.llama.set_clip_options(options) - - def reset_clip_options(self): - self.clip_l.reset_clip_options() - self.llama.reset_clip_options() - - def encode_token_weights(self, token_weight_pairs): - token_weight_pairs_l = token_weight_pairs["l"] - token_weight_pairs_llama = token_weight_pairs["llama"] - - llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) - - template_end = 0 - extra_template_end = 0 - extra_sizes = 0 - user_end = 9999999999999 - images = [] - - tok_pairs = token_weight_pairs_llama[0] - for i, v in enumerate(tok_pairs): - elem = v[0] - if not torch.is_tensor(elem): - if isinstance(elem, numbers.Integral): - if elem == 128006: - if tok_pairs[i + 1][0] == 882: - if tok_pairs[i + 2][0] == 128007: - template_end = i + 2 - user_end = -1 - if elem == 128009 and user_end == -1: - user_end = i + 1 - else: - if elem.get("original_type") == "image": - elem_size = elem.get("data").shape[0] - if template_end > 0: - if user_end == -1: - extra_template_end += elem_size - 1 
+            code_without_unused_defs = remove_unused_definitions_by_function_names(
+                original_code, qualified_function_names | helpers_of_helpers_qualified_names
+            )
             code_context = parse_code_and_prune_cst(
-                original_code,
+                code_without_unused_defs,
                 code_context_type,
                 qualified_function_names,
                 helpers_of_helpers_qualified_names,
                 remove_docstrings,
             )
-            code_context = remove_unused_definitions_by_function_names(
-                code_context, qualified_function_names | helpers_of_helpers_qualified_names
-            )
         except ValueError as e:
             logger.debug(f"Error while getting read-only code: {e}")
@@ -323,11 +325,9 @@ def extract_code_markdown_context_from_files(
             continue
         try:
             qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
+            code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names)
             code_context = parse_code_and_prune_cst(
-                original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
-            )
-            code_context = remove_unused_definitions_by_function_names(
-                code_context, qualified_helper_function_names
+                code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
             )
         except ValueError as e:
             logger.debug(f"Error while getting read-only code: {e}")
diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py
index 6bf466af5..90356ac10 100644
--- a/tests/test_code_context_extractor.py
+++ b/tests/test_code_context_extractor.py
@@ -1363,254 +1363,354 @@ def function_to_optimize():
     assert read_write_context.strip() == expected_read_write_context.strip()
     assert read_only_context.strip() == expected_read_only_context.strip()
 
-def test_comfy_module_import() -> None:
-    code = '''
-import model_management
-
-class HunyuanVideoClipModel(torch.nn.Module):
-    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
-        super().__init__()
-        dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
-        self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
-        self.dtypes = set([dtype, dtype_llama])
-
-    def set_clip_options(self, options):
-        self.clip_l.set_clip_options(options)
-        self.llama.set_clip_options(options)
-
-    def reset_clip_options(self):
-        self.clip_l.reset_clip_options()
-        self.llama.reset_clip_options()
-
-    def encode_token_weights(self, token_weight_pairs):
-        token_weight_pairs_l = token_weight_pairs["l"]
-        token_weight_pairs_llama = token_weight_pairs["llama"]
-
-        llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
-
-        template_end = 0
-        extra_template_end = 0
-        extra_sizes = 0
-        user_end = 9999999999999
-        images = []
-
-        tok_pairs = token_weight_pairs_llama[0]
-        for i, v in enumerate(tok_pairs):
-            elem = v[0]
-            if not torch.is_tensor(elem):
-                if isinstance(elem, numbers.Integral):
-                    if elem == 128006:
-                        if tok_pairs[i + 1][0] == 882:
-                            if tok_pairs[i + 2][0] == 128007:
-                                template_end = i + 2
-                                user_end = -1
-                    if elem == 128009 and user_end == -1:
-                        user_end = i + 1
-                else:
-                    if elem.get("original_type") == "image":
-                        elem_size = elem.get("data").shape[0]
-                        if template_end > 0:
-                            if user_end == -1:
-                                extra_template_end += elem_size - 1
-                            else:
-                                image_start = i + extra_sizes
-                                image_end = i + elem_size + extra_sizes
-                                images.append((image_start, image_end, elem.get("image_interleave", 1)))
-                                extra_sizes += elem_size - 1
-
-        if llama_out.shape[1] > (template_end + 2):
-            if tok_pairs[template_end + 1][0] == 271:
-                template_end += 2
-        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
-        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
-        if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
-            llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
-
-        if len(images) > 0:
-            out = []
-            for i in images:
-                out.append(llama_out[:, i[0]: i[1]: i[2]])
-            llama_output = torch.cat(out + [llama_output], dim=1)
-
-        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return llama_output, l_pooled, llama_extra_out
-
-    def load_sd(self, sd):
-        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
-            return self.clip_l.load_sd(sd)
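+# End-to-end check: context extraction should drop the module-level definitions
+# that the optimized function's helpers never touch (marked "Unused variable"
+# in the fixture below) while keeping everything those helpers depend on.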
+def test_module_import_optimization() -> None:
+    main_code = '''
+import utility_module
+
+class Calculator:
+    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
+        # This is where we use the imported module
+        self.precision = utility_module.select_precision(precision, fallback_precision)
+        self.mode = mode
+
+        # Using variables from the utility module
+        self.backend = utility_module.CALCULATION_BACKEND
+        self.system = utility_module.SYSTEM_TYPE
+        self.default_precision = utility_module.DEFAULT_PRECISION
+
+    def add(self, a, b):
+        return a + b
+
+    def subtract(self, a, b):
+        return a - b
+
+    def calculate(self, operation, x, y):
+        if operation == "add":
+            return self.add(x, y)
+        elif operation == "subtract":
+            return self.subtract(x, y)
         else:
-            return self.llama.load_sd(sd)
+            return None
 '''
-    model_management_code = '''
-import psutil
-import logging
-from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
-import torch
+
+    utility_module_code = '''
 import sys
 import platform
-import weakref
-import gc
-
-class VRAMState(Enum):
-    DISABLED = 0    #No vram present: no need to move models to vram
-    NO_VRAM = 1     #Very low vram: enable all the options to save vram
-    LOW_VRAM = 2
-    NORMAL_VRAM = 3
-    HIGH_VRAM = 4
-    SHARED = 5      #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
-
-class CPUState(Enum):
-    GPU = 0
-    CPU = 1
-    MPS = 2
-
-# Determine VRAM State
-vram_state = VRAMState.NORMAL_VRAM
-set_vram_to = VRAMState.NORMAL_VRAM
-cpu_state = CPUState.GPU
-
-total_vram = 0
-
-def get_supported_float8_types():
-    float8_types = []
-    try:
-        float8_types.append(torch.float8_e4m3fn)
-    except:
-        pass
-    try:
-        float8_types.append(torch.float8_e4m3fnuz)
-    except:
-        pass
-    try:
-        float8_types.append(torch.float8_e5m2)
-    except:
-        pass
-    try:
-        float8_types.append(torch.float8_e5m2fnuz)
-    except:
-        pass
-    try:
-        float8_types.append(torch.float8_e8m0fnu)
-    except:
-        pass
-    return float8_types
+import logging
 
-FLOAT8_TYPES = get_supported_float8_types()
+DEFAULT_PRECISION = "medium"
+DEFAULT_MODE = "standard"
 
-xpu_available = False
-torch_version = ""
+# Try-except block with variable definitions
 try:
-    torch_version = torch.version.__version__
-    temp = torch_version.split(".")
-    torch_version_numeric = (int(temp[0]), int(temp[1]))
-    xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
-except:
-    pass
-
-lowvram_available = True
-if args.deterministic:
-    logging.info("Using deterministic algorithms for pytorch")
-    torch.use_deterministic_algorithms(True, warn_only=True)
-
-directml_enabled = False
-if args.directml is not None:
-    import torch_directml
-    directml_enabled = True
-    device_index = args.directml
-    if device_index < 0:
-        directml_device = torch_directml.device()
+    import numpy as np
+    # Used variable in try block
+    CALCULATION_BACKEND = "numpy"
+    # Unused variable in try block
+    VECTOR_DIMENSIONS = 3
+except ImportError:
+    # Used variable in except block
+    CALCULATION_BACKEND = "python"
+    # Unused variable in except block
+    FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
+
+# Nested if-else with variable definitions
+if sys.platform.startswith('win'):
+    # Used variable in outer if
+    SYSTEM_TYPE = "windows"
+    if platform.architecture()[0] == '64bit':
+        # Unused variable in nested if
+        MEMORY_MODEL = "x64"
     else:
-        directml_device = torch_directml.device(device_index)
-    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
-    # torch_directml.disable_tiled_resources(True)
-    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
+        # Unused variable in nested else
+        MEMORY_MODEL = "x86"
+elif sys.platform.startswith('linux'):
+    # Used variable in outer elif
+    SYSTEM_TYPE = "linux"
+    # Unused variable in outer elif
+    KERNEL_VERSION = platform.release()
+else:
+    # Used variable in outer else
+    SYSTEM_TYPE = "other"
+    # Unused variable in outer else
+    UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
+
+# Function that will be used in the main code
+def select_precision(precision, fallback_precision):
+    if precision is None:
+        return fallback_precision or DEFAULT_PRECISION
+
+    # Using the variables defined above
+    if CALCULATION_BACKEND == "numpy":
+        # Higher precision available with NumPy
+        precision_options = ["low", "medium", "high", "ultra"]
+    else:
+        # Limited precision without NumPy
+        precision_options = ["low", "medium", "high"]
+
+    if isinstance(precision, str):
+        if precision.lower() not in precision_options:
+            if fallback_precision:
+                return fallback_precision
+            else:
+                return DEFAULT_PRECISION
+        return precision.lower()
+    else:
+        return DEFAULT_PRECISION
+
+# Function that won't be used
+def get_system_details():
+    return {
+        "system": SYSTEM_TYPE,
+        "backend": CALCULATION_BACKEND,
+        "default_precision": DEFAULT_PRECISION,
+        "python_version": sys.version
+    }
+'''
-
-try:
-    import intel_extension_for_pytorch as ipex
-    _ = torch.xpu.device_count()
-    xpu_available = xpu_available or torch.xpu.is_available()
-except:
-    xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
+
+    # Create a temporary directory for the test
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up the package structure
+        package_dir = Path(temp_dir) / "package"
+        package_dir.mkdir()
-
-try:
-    if torch.backends.mps.is_available():
-        cpu_state = CPUState.MPS
-    import torch.mps
-except:
-    pass
+
+        # Create the __init__.py file
+        with open(package_dir / "__init__.py", "w") as init_file:
+            init_file.write("")
-
-try:
-    import torch_npu  # noqa: F401
-    _ = torch.npu.device_count()
-    npu_available = torch.npu.is_available()
-except:
-    npu_available = False
+
+        # Write the utility_module.py file
+        with open(package_dir / "utility_module.py", "w") as utility_file:
+            utility_file.write(utility_module_code)
+            utility_file.flush()
+
+        # Write the main code file
+        main_file_path = package_dir / "main_module.py"
+        with open(main_file_path, "w") as main_file:
+            main_file.write(main_code)
+            main_file.flush()
+
+        # Set up the optimizer
+        file_path = main_file_path.resolve()
+        opt = Optimizer(
+            Namespace(
+                project_root=package_dir.resolve(),
+                disable_telemetry=True,
+                tests_root="tests",
+                test_framework="pytest",
+                pytest_cmd="pytest",
+                experiment_id=None,
+                test_project_root=Path().resolve(),
+            )
+        )
+
+        # Define the function to optimize
+        function_to_optimize = FunctionToOptimize(
+            function_name="calculate",
+            file_path=file_path,
+            parents=[FunctionParent(name="Calculator", type="ClassDef")],
+            starting_line=None,
+            ending_line=None,
+        )
+        # Get the code optimization context
+        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
+        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
+        # The expected contexts
+        expected_read_write_context = """
+import utility_module
+
+class Calculator:
+    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
+        # This is where we use the imported module
+        self.precision = utility_module.select_precision(precision, fallback_precision)
+        self.mode = mode
+
+        # Using variables from the utility module
+        self.backend = utility_module.CALCULATION_BACKEND
+        self.system = utility_module.SYSTEM_TYPE
+        self.default_precision = utility_module.DEFAULT_PRECISION
+
+    def add(self, a, b):
+        return a + b
+
+    def subtract(self, a, b):
+        return a - b
+
+    def calculate(self, operation, x, y):
+        if operation == "add":
+            return self.add(x, y)
+        elif operation == "subtract":
+            return self.subtract(x, y)
+        else:
+            return None
+"""
+        expected_read_only_context = """
+```python:utility_module.py
+DEFAULT_PRECISION = "medium"
+
+# Try-except block with variable definitions
 try:
-    import torch_mlu  # noqa: F401
-    _ = torch.mlu.device_count()
-    mlu_available = torch.mlu.is_available()
-except:
-    mlu_available = False
-
-if args.cpu:
-    cpu_state = CPUState.CPU
-
-def supports_cast(device, dtype): #TODO
-    if dtype == torch.float32:
-        return True
-    if dtype == torch.float16:
-        return True
-    if directml_enabled: #TODO: test this
-        return False
-    if dtype == torch.bfloat16:
-        return True
-    if is_device_mps(device):
-        return False
-    if dtype == torch.float8_e4m3fn:
-        return True
-    if dtype == torch.float8_e5m2:
-        return True
-    return False
-
-def pick_weight_dtype(dtype, fallback_dtype, device=None):
-    if dtype is None:
-        dtype = fallback_dtype
-    elif dtype_size(dtype) > dtype_size(fallback_dtype):
-        dtype = fallback_dtype
-
-    if not supports_cast(device, dtype):
-        dtype = fallback_dtype
-
-    return dtype
+    # Used variable in try block
+    CALCULATION_BACKEND = "numpy"
+except ImportError:
+    # Used variable in except block
+    CALCULATION_BACKEND = "python"
+
+# Function that will be used in the main code
+def select_precision(precision, fallback_precision):
+    if precision is None:
+        return fallback_precision or DEFAULT_PRECISION
+
+    # Using the variables defined above
+    if CALCULATION_BACKEND == "numpy":
+        # Higher precision available with NumPy
+        precision_options = ["low", "medium", "high", "ultra"]
+    else:
+        # Limited precision without NumPy
+        precision_options = ["low", "medium", "high"]
+
+    if isinstance(precision, str):
+        if precision.lower() not in precision_options:
+            if fallback_precision:
+                return fallback_precision
+            else:
+                return DEFAULT_PRECISION
+        return precision.lower()
+    else:
+        return DEFAULT_PRECISION
+```
+"""
+        # Verify the contexts match the expected values
+        assert read_write_context.strip() == expected_read_write_context.strip()
+        assert read_only_context.strip() == expected_read_only_context.strip()
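+
+# Mirrors the test above with __init__ as the function to optimize: its helper
+# select_precision moves into the read-write context, so the read-only context
+# keeps only the module-level names that helper reads (DEFAULT_PRECISION and
+# the CALCULATION_BACKEND assignments inside the try/except).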
+
+def test_module_import_init_fto() -> None:
+    main_code = '''
+import utility_module
+class Calculator:
+    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
+        # This is where we use the imported module
+        self.precision = utility_module.select_precision(precision, fallback_precision)
+        self.mode = mode
+        # Using variables from the utility module
+        self.backend = utility_module.CALCULATION_BACKEND
+        self.system = utility_module.SYSTEM_TYPE
+        self.default_precision = utility_module.DEFAULT_PRECISION
+
+    def add(self, a, b):
+        return a + b
+
+    def subtract(self, a, b):
+        return a - b
+
+    def calculate(self, operation, x, y):
+        if operation == "add":
+            return self.add(x, y)
+        elif operation == "subtract":
+            return self.subtract(x, y)
+        else:
+            return None
 '''
+
+    utility_module_code = '''
+import sys
+import platform
+import logging
+
+DEFAULT_PRECISION = "medium"
+DEFAULT_MODE = "standard"
+
+# Try-except block with variable definitions
+try:
+    import numpy as np
+    # Used variable in try block
+    CALCULATION_BACKEND = "numpy"
+    # Unused variable in try block
+    VECTOR_DIMENSIONS = 3
+except ImportError:
+    # Used variable in except block
+    CALCULATION_BACKEND = "python"
+    # Unused variable in except block
+    FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
+
+# Nested if-else with variable definitions
+if sys.platform.startswith('win'):
+    # Used variable in outer if
+    SYSTEM_TYPE = "windows"
+    if platform.architecture()[0] == '64bit':
+        # Unused variable in nested if
+        MEMORY_MODEL = "x64"
+    else:
+        # Unused variable in nested else
+        MEMORY_MODEL = "x86"
+elif sys.platform.startswith('linux'):
+    # Used variable in outer elif
+    SYSTEM_TYPE = "linux"
+    # Unused variable in outer elif
+    KERNEL_VERSION = platform.release()
+else:
+    # Used variable in outer else
+    SYSTEM_TYPE = "other"
+    # Unused variable in outer else
+    UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
+
+# Function that will be used in the main code
+def select_precision(precision, fallback_precision):
+    if precision is None:
+        return fallback_precision or DEFAULT_PRECISION
+
+    # Using the variables defined above
+    if CALCULATION_BACKEND == "numpy":
+        # Higher precision available with NumPy
+        precision_options = ["low", "medium", "high", "ultra"]
+    else:
+        # Limited precision without NumPy
+        precision_options = ["low", "medium", "high"]
+
+    if isinstance(precision, str):
+        if precision.lower() not in precision_options:
+            if fallback_precision:
+                return fallback_precision
+            else:
+                return DEFAULT_PRECISION
+        return precision.lower()
+    else:
+        return DEFAULT_PRECISION
+
+# Function that won't be used
+def get_system_details():
+    return {
+        "system": SYSTEM_TYPE,
+        "backend": CALCULATION_BACKEND,
+        "default_precision": DEFAULT_PRECISION,
+        "python_version": sys.version
+    }
+'''
 
-    # Create a temporary directory instead of a single file
+    # Create a temporary directory for the test
     with tempfile.TemporaryDirectory() as temp_dir:
-        # Create a package structure
+        # Set up the package structure
         package_dir = Path(temp_dir) / "package"
         package_dir.mkdir()
 
-        # Create the __init__.py file to make it a proper package
+        # Create the __init__.py file
         with open(package_dir / "__init__.py", "w") as init_file:
             init_file.write("")
 
-        # Write the model_management.py file
-        with open(package_dir / "model_management.py", "w") as model_file:
-            model_file.write(model_management_code)
-            model_file.flush()
+        # Write the utility_module.py file
+        with open(package_dir / "utility_module.py", "w") as utility_file:
+            utility_file.write(utility_module_code)
+            utility_file.flush()
 
-        # Write the main code file that imports from model_management
+        # Write the main code file
         main_file_path = package_dir / "main_module.py"
         with open(main_file_path, "w") as main_file:
-            main_file.write(code)
+            main_file.write(main_code)
             main_file.flush()
 
-        # Now set up the optimizer with the path to the main file
+        # Set up the optimizer
        file_path = main_file_path.resolve()
         opt = Optimizer(
             Namespace(
@@ -1624,95 +1724,70 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
             )
         )
 
+        # Define the function to optimize
         function_to_optimize = FunctionToOptimize(
-            function_name="encode_token_weights",
+            function_name="__init__",
             file_path=file_path,
-            parents=[FunctionParent(name="HunyuanVideoClipModel", type="ClassDef")],
+            parents=[FunctionParent(name="Calculator", type="ClassDef")],
             starting_line=None,
             ending_line=None,
         )
+        # Get the code optimization context
         code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
         read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
+        # The expected contexts
         expected_read_write_context = """
-import model_management
-
-class HunyuanVideoClipModel(torch.nn.Module):
-    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
-        super().__init__()
-        dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
-        self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
-        self.dtypes = set([dtype, dtype_llama])
-
-    def encode_token_weights(self, token_weight_pairs):
-        token_weight_pairs_l = token_weight_pairs["l"]
-        token_weight_pairs_llama = token_weight_pairs["llama"]
-
-        llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
-
-        template_end = 0
-        extra_template_end = 0
-        extra_sizes = 0
-        user_end = 9999999999999
-        images = []
-
-        tok_pairs = token_weight_pairs_llama[0]
-        for i, v in enumerate(tok_pairs):
-            elem = v[0]
-            if not torch.is_tensor(elem):
-                if isinstance(elem, numbers.Integral):
-                    if elem == 128006:
-                        if tok_pairs[i + 1][0] == 882:
-                            if tok_pairs[i + 2][0] == 128007:
-                                template_end = i + 2
-                                user_end = -1
-                    if elem == 128009 and user_end == -1:
-                        user_end = i + 1
-                else:
-                    if elem.get("original_type") == "image":
-                        elem_size = elem.get("data").shape[0]
-                        if template_end > 0:
-                            if user_end == -1:
-                                extra_template_end += elem_size - 1
-                            else:
-                                image_start = i + extra_sizes
-                                image_end = i + elem_size + extra_sizes
-                                images.append((image_start, image_end, elem.get("image_interleave", 1)))
-                                extra_sizes += elem_size - 1
-
-        if llama_out.shape[1] > (template_end + 2):
-            if tok_pairs[template_end + 1][0] == 271:
-                template_end += 2
-        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
-        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
-        if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
-            llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
-
-        if len(images) > 0:
-            out = []
-            for i in images:
-                out.append(llama_out[:, i[0]: i[1]: i[2]])
-            llama_output = torch.cat(out + [llama_output], dim=1)
-
-        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return llama_output, l_pooled, llama_extra_out
-"""
-        expected_read_only_context = """
-```python:model_management.py
-# Determine VRAM State
+# Function that will be used in the main code
+
+import utility_module
+
+def select_precision(precision, fallback_precision):
+    if precision is None:
+        return fallback_precision or DEFAULT_PRECISION
+
+    # Using the variables defined above
+    if CALCULATION_BACKEND == "numpy":
+        # Higher precision available with NumPy
+        precision_options = ["low", "medium", "high", "ultra"]
+    else:
+        # Limited precision without NumPy
+        precision_options = ["low", "medium", "high"]
+
+    if isinstance(precision, str):
+        if precision.lower() not in precision_options:
+            if fallback_precision:
+                return fallback_precision
+            else:
+                return DEFAULT_PRECISION
+        return precision.lower()
+    else:
+        return DEFAULT_PRECISION
 
-def pick_weight_dtype(dtype, fallback_dtype, device=None):
-    if dtype is None:
-        dtype = fallback_dtype
-    elif dtype_size(dtype) > dtype_size(fallback_dtype):
-        dtype = fallback_dtype
-
-    if not supports_cast(device, dtype):
-        dtype = fallback_dtype
+class Calculator:
+    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
+        # This is where we use the imported module
+        self.precision = utility_module.select_precision(precision, fallback_precision)
+        self.mode = mode
-
-    return dtype
+
+        # Using variables from the utility module
+        self.backend = utility_module.CALCULATION_BACKEND
+        self.system = utility_module.SYSTEM_TYPE
+        self.default_precision = utility_module.DEFAULT_PRECISION
+"""
+        expected_read_only_context = """
+```python:utility_module.py
+DEFAULT_PRECISION = "medium"
+
+# Try-except block with variable definitions
+try:
+    # Used variable in try block
+    CALCULATION_BACKEND = "numpy"
+except ImportError:
+    # Used variable in except block
+    CALCULATION_BACKEND = "python"
 ```
 """
 
         assert read_write_context.strip() == expected_read_write_context.strip()