|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from collections import defaultdict |
3 | 4 | from dataclasses import dataclass, field |
4 | 5 |
|
5 | 6 | import libcst as cst |
@@ -255,24 +256,24 @@ class QualifiedFunctionUsageMarker: |
255 | 256 | def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None: |
256 | 257 | self.definitions = definitions |
257 | 258 | self.qualified_function_names = qualified_function_names |
| 259 | + self.class_dunder_methods = self._preprocess_definitions() |
258 | 260 | self.expanded_qualified_functions = self._expand_qualified_functions() |
259 | 261 |
|
260 | 262 | def _expand_qualified_functions(self) -> set[str]: |
261 | 263 | """Expand the qualified function names to include related methods.""" |
262 | 264 | expanded = set(self.qualified_function_names) |
263 | 265 |
|
264 | 266 | # Find class methods and add their containing classes and dunder methods |
265 | | - for qualified_name in list(self.qualified_function_names): |
| 267 | + for qualified_name in self.qualified_function_names: |
266 | 268 | if "." in qualified_name: |
267 | 269 | class_name, method_name = qualified_name.split(".", 1) |
268 | 270 |
|
269 | 271 | # Add the class itself |
270 | 272 | expanded.add(class_name) |
271 | 273 |
|
272 | 274 | # Add all dunder methods of the class |
273 | | - for name in self.definitions: |
274 | | - if name.startswith(f"{class_name}.__") and name.endswith("__"): |
275 | | - expanded.add(name) |
| 275 | + if class_name in self.class_dunder_methods: |
| 276 | + expanded.update(self.class_dunder_methods[class_name]) |
276 | 277 |
|
277 | 278 | return expanded |
278 | 279 |
|
@@ -301,9 +302,21 @@ def mark_as_used_recursively(self, name: str) -> None: |
301 | 302 | for dep in self.definitions[name].dependencies: |
302 | 303 | self.mark_as_used_recursively(dep) |
303 | 304 |
|
| 305 | + def _preprocess_definitions(self) -> dict[str, set[str]]: |
| 306 | + """Preprocess definitions to find dunder methods for each class.""" |
| 307 | + class_dunder_methods = defaultdict(set) |
| 308 | + |
| 309 | + for name in self.definitions: |
| 310 | + if name.count(".") == 1: |
| 311 | + class_name, method_name = name.split(".", 1) |
| 312 | + if method_name.startswith("__") and method_name.endswith("__"): |
| 313 | + class_dunder_methods[class_name].add(name) |
| 314 | + |
| 315 | + return class_dunder_methods |
| 316 | + |
304 | 317 |
|
305 | 318 | def remove_unused_definitions_recursively( |
306 | | - node: cst.CSTNode, definitions: dict[str, UsageInfo] |
| 319 | + node: cst.CSTNode, definitions: dict[str, UsageInfo] |
307 | 320 | ) -> tuple[cst.CSTNode | None, bool]: |
308 | 321 | """Recursively filter the node to remove unused definitions. |
309 | 322 |
|
@@ -358,7 +371,10 @@ def remove_unused_definitions_recursively( |
358 | 371 | names = extract_names_from_targets(target.target) |
359 | 372 | for name in names: |
360 | 373 | class_var_name = f"{class_name}.{name}" |
361 | | - if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: |
| 374 | + if ( |
| 375 | + class_var_name in definitions |
| 376 | + and definitions[class_var_name].used_by_qualified_function |
| 377 | + ): |
362 | 378 | var_used = True |
363 | 379 | method_or_var_used = True |
364 | 380 | break |
|
0 commit comments