Skip to content

Commit 0f268e6

Browse files
Follow-up to graph layout cleanup: improvements for tests and typing (#3728)
* suggestions from review on #3434 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3adcaa commit 0f268e6

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

manim/mobject/graph.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99

1010
import itertools as it
1111
from copy import copy
12-
from typing import Any, Hashable, Iterable, Literal, Protocol, Union, cast
12+
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Protocol, cast
1313

1414
import networkx as nx
1515
import numpy as np
16-
from typing_extensions import TypeAlias
16+
17+
if TYPE_CHECKING:
18+
from typing_extensions import TypeAlias
19+
20+
from manim.typing import Point3D
21+
22+
NxGraph: TypeAlias = nx.classes.graph.Graph | nx.classes.digraph.DiGraph
1723

1824
from manim.animation.composition import AnimationGroup
1925
from manim.animation.creation import Create, Uncreate
@@ -24,11 +30,8 @@
2430
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
2531
from manim.mobject.text.tex_mobject import MathTex
2632
from manim.mobject.types.vectorized_mobject import VMobject
27-
from manim.typing import Point3D
2833
from manim.utils.color import BLACK
2934

30-
NxGraph: TypeAlias = Union[nx.classes.graph.Graph, nx.classes.digraph.DiGraph]
31-
3235

3336
class LayoutFunction(Protocol):
3437
"""A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`.
@@ -53,8 +56,8 @@ def custom_layout(
5356
graph: nx.Graph,
5457
scale: float | tuple[float, float, float] = 2,
5558
n: int | None = None,
56-
*args: tuple[Any, ...],
57-
**kwargs: dict[str, Any],
59+
*args: Any,
60+
**kwargs: Any,
5861
):
5962
nodes = sorted(list(graph))
6063
height = len(nodes) // n
@@ -256,8 +259,8 @@ def __call__(
256259
self,
257260
graph: NxGraph,
258261
scale: float | tuple[float, float, float] = 2,
259-
*args: tuple[Any, ...],
260-
**kwargs: dict[str, Any],
262+
*args: Any,
263+
**kwargs: Any,
261264
) -> dict[Hashable, Point3D]:
262265
"""Given a graph and a scale, return a dictionary of coordinates.
263266
@@ -280,7 +283,7 @@ def _partite_layout(
280283
nx_graph: NxGraph,
281284
scale: float = 2,
282285
partitions: list[list[Hashable]] | None = None,
283-
**kwargs: dict[str, Any],
286+
**kwargs: Any,
284287
) -> dict[Hashable, Point3D]:
285288
if partitions is None or len(partitions) == 0:
286289
raise ValueError(
@@ -302,7 +305,7 @@ def _partite_layout(
302305
return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs)
303306

304307

305-
def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: dict[str, Any]):
308+
def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: Any):
306309
# the random layout places coordinates in [0, 1)
307310
# we need to rescale manually afterwards...
308311
auto_layout = nx.layout.random_layout(nx_graph, **kwargs)

tests/module/mobject/test_graph.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ def test_custom_graph_layout_dict():
116116

117117

118118
def test_graph_layouts():
119-
for layout in (
120-
layout for layout in _layouts if layout != "tree" and layout != "partite"
121-
):
119+
for layout in (layout for layout in _layouts if layout not in ["tree", "partite"]):
122120
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout)
123121
assert str(G) == "Undirected graph on 3 vertices and 2 edges"
124122

@@ -164,9 +162,7 @@ def layout_func(graph, scale, offset):
164162

165163

166164
def test_graph_change_layout():
167-
for layout in (
168-
layout for layout in _layouts if layout != "tree" and layout != "partite"
169-
):
165+
for layout in (layout for layout in _layouts if layout not in ["tree", "partite"]):
170166
G = Graph([1, 2, 3], [(1, 2), (2, 3)])
171167
G.change_layout(layout=layout)
172168
assert str(G) == "Undirected graph on 3 vertices and 2 edges"

0 commit comments

Comments
 (0)