99
1010import itertools as it
1111from copy import copy
12- from typing import Any , Hashable , Iterable , Literal , Protocol , Union , cast
12+ from typing import TYPE_CHECKING , Any , Hashable , Iterable , Literal , Protocol , cast
1313
1414import networkx as nx
1515import numpy as np
16- from typing_extensions import TypeAlias
16+
17+ if TYPE_CHECKING :
18+ from typing_extensions import TypeAlias
19+
20+ from manim .typing import Point3D
21+
22+ NxGraph : TypeAlias = nx .classes .graph .Graph | nx .classes .digraph .DiGraph
1723
1824from manim .animation .composition import AnimationGroup
1925from manim .animation .creation import Create , Uncreate
2430from manim .mobject .opengl .opengl_mobject import OpenGLMobject
2531from manim .mobject .text .tex_mobject import MathTex
2632from manim .mobject .types .vectorized_mobject import VMobject
27- from manim .typing import Point3D
2833from manim .utils .color import BLACK
2934
30- NxGraph : TypeAlias = Union [nx .classes .graph .Graph , nx .classes .digraph .DiGraph ]
31-
3235
3336class LayoutFunction (Protocol ):
3437 """A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`.
@@ -53,8 +56,8 @@ def custom_layout(
5356 graph: nx.Graph,
5457 scale: float | tuple[float, float, float] = 2,
5558 n: int | None = None,
56- *args: tuple[ Any, ...] ,
57- **kwargs: dict[str, Any] ,
59+ *args: Any,
60+ **kwargs: Any,
5861 ):
5962 nodes = sorted(list(graph))
6063 height = len(nodes) // n
@@ -256,8 +259,8 @@ def __call__(
256259 self ,
257260 graph : NxGraph ,
258261 scale : float | tuple [float , float , float ] = 2 ,
259- * args : tuple [ Any , ...] ,
260- ** kwargs : dict [ str , Any ] ,
262+ * args : Any ,
263+ ** kwargs : Any ,
261264 ) -> dict [Hashable , Point3D ]:
262265 """Given a graph and a scale, return a dictionary of coordinates.
263266
@@ -280,7 +283,7 @@ def _partite_layout(
280283 nx_graph : NxGraph ,
281284 scale : float = 2 ,
282285 partitions : list [list [Hashable ]] | None = None ,
283- ** kwargs : dict [ str , Any ] ,
286+ ** kwargs : Any ,
284287) -> dict [Hashable , Point3D ]:
285288 if partitions is None or len (partitions ) == 0 :
286289 raise ValueError (
@@ -302,7 +305,7 @@ def _partite_layout(
302305 return nx .layout .multipartite_layout (nx_graph , scale = scale , ** kwargs )
303306
304307
305- def _random_layout (nx_graph : NxGraph , scale : float = 2 , ** kwargs : dict [ str , Any ] ):
308+ def _random_layout (nx_graph : NxGraph , scale : float = 2 , ** kwargs : Any ):
306309 # the random layout places coordinates in [0, 1)
307310 # we need to rescale manually afterwards...
308311 auto_layout = nx .layout .random_layout (nx_graph , ** kwargs )
0 commit comments