diff --git a/manim/animation/animation.py b/manim/animation/animation.py index 3f2c4d35fd..c597a0bdc0 100644 --- a/manim/animation/animation.py +++ b/manim/animation/animation.py @@ -134,6 +134,9 @@ def __init__( name: str = None, remover: bool = False, # remove a mobject from the screen? suspend_mobject_updating: bool = True, + introducer: bool = False, + *, + _on_finish: Callable[[], None] = lambda _: None, **kwargs, ) -> None: self._typecheck_input(mobject) @@ -141,8 +144,10 @@ def __init__( self.rate_func: Callable[[float], float] = rate_func self.name: str | None = name self.remover: bool = remover + self.introducer: bool = introducer self.suspend_mobject_updating: bool = suspend_mobject_updating self.lag_ratio: float = lag_ratio + self._on_finish: Callable[[Scene], None] = _on_finish if config["renderer"] == "opengl": self.starting_mobject: OpenGLMobject = OpenGLMobject() self.mobject: OpenGLMobject = ( @@ -219,9 +224,26 @@ def clean_up_from_scene(self, scene: Scene) -> None: scene The scene the animation should be cleaned up from. """ + self._on_finish(scene) if self.is_remover(): scene.remove(self.mobject) + def _setup_scene(self, scene: Scene) -> None: + """Setup up the :class:`~.Scene` before starting the animation. + + This includes to :meth:`~.Scene.add` the Animation's + :class:`~.Mobject` if the animation is an introducer. + + Parameters + ---------- + scene + The scene the animation should be cleaned up from. + """ + if scene is None: + return + if self.is_introducer(): + scene.add(self.mobject) + def create_starting_mobject(self) -> Mobject: # Keep track of where the mobject starts return self.mobject.copy() @@ -436,6 +458,16 @@ def is_remover(self) -> bool: """ return self.remover + def is_introducer(self) -> bool: + """Test if a the animation is a remover. + + Returns + ------- + bool + ``True`` if the animation is a remover, ``False`` otherwise. + """ + return self.introducer + def prepare_animation( anim: Animation | mobject._AnimationBuilder, diff --git a/manim/animation/composition.py b/manim/animation/composition.py index d68b0e6ce8..4a1944404b 100644 --- a/manim/animation/composition.py +++ b/manim/animation/composition.py @@ -39,7 +39,7 @@ def __init__( self.group = group if self.group is None: mobjects = remove_list_redundancies( - [anim.mobject for anim in self.animations], + [anim.mobject for anim in self.animations if not anim.is_introducer()], ) if config["renderer"] == "opengl": self.group = OpenGLGroup(*mobjects) @@ -57,6 +57,10 @@ def begin(self) -> None: for anim in self.animations: anim.begin() + def _setup_scene(self, scene) -> None: + for anim in self.animations: + anim._setup_scene(scene) + def finish(self) -> None: for anim in self.animations: anim.finish() @@ -64,6 +68,7 @@ def finish(self) -> None: self.group.resume_updating() def clean_up_from_scene(self, scene: Scene) -> None: + self._on_finish(scene) for anim in self.animations: if self.remover: anim.remover = self.remover @@ -127,6 +132,16 @@ def update_mobjects(self, dt: float) -> None: if self.active_animation: self.active_animation.update_mobjects(dt) + def _setup_scene(self, scene) -> None: + if scene is None: + return + if self.is_introducer(): + for anim in self.animations: + if not anim.is_introducer() and anim.mobject is not None: + scene.add(anim.mobject) + + self.scene = scene + def update_active_animation(self, index: int) -> None: self.active_index = index if index >= len(self.animations): @@ -135,6 +150,7 @@ def update_active_animation(self, index: int) -> None: self.active_end_time: float | None = None else: self.active_animation = self.animations[index] + self.active_animation._setup_scene(self.scene) self.active_animation.begin() self.active_start_time = self.anims_with_timings[index][1] self.active_end_time = self.anims_with_timings[index][2] diff --git a/manim/animation/creation.py b/manim/animation/creation.py index 61d21f2773..6c592fbd47 100644 --- a/manim/animation/creation.py +++ b/manim/animation/creation.py @@ -115,7 +115,9 @@ class ShowPartial(Animation): """ def __init__( - self, mobject: VMobject | OpenGLVMobject | OpenGLSurface | None, **kwargs + self, + mobject: VMobject | OpenGLVMobject | OpenGLSurface | None, + **kwargs, ): pointwise = getattr(mobject, "pointwise_become_partial", None) if not callable(pointwise): @@ -167,9 +169,10 @@ def __init__( self, mobject: VMobject | OpenGLVMobject | OpenGLSurface, lag_ratio: float = 1.0, + introducer: bool = True, **kwargs, ) -> None: - super().__init__(mobject, lag_ratio=lag_ratio, **kwargs) + super().__init__(mobject, lag_ratio=lag_ratio, introducer=introducer, **kwargs) def _get_bounds(self, alpha: float) -> tuple[int, float]: return (0, alpha) @@ -199,7 +202,13 @@ def __init__( remover: bool = True, **kwargs, ) -> None: - super().__init__(mobject, rate_func=rate_func, remover=remover, **kwargs) + super().__init__( + mobject, + rate_func=rate_func, + introducer=False, + remover=remover, + **kwargs, + ) class DrawBorderThenFill(Animation): @@ -223,10 +232,17 @@ def __init__( stroke_color: str = None, draw_border_animation_config: dict = {}, # what does this dict accept? fill_animation_config: dict = {}, + introducer: bool = True, **kwargs, ) -> None: self._typecheck_input(vmobject) - super().__init__(vmobject, run_time=run_time, rate_func=rate_func, **kwargs) + super().__init__( + vmobject, + run_time=run_time, + introducer=introducer, + rate_func=rate_func, + **kwargs, + ) self.stroke_width = stroke_width self.stroke_color = stroke_color self.draw_border_animation_config = draw_border_animation_config @@ -308,11 +324,14 @@ def __init__( lag_ratio, ) self.reverse = reverse + if "remover" not in kwargs: + kwargs["remover"] = reverse super().__init__( vmobject, rate_func=rate_func, run_time=run_time, lag_ratio=lag_ratio, + introducer=not reverse, **kwargs, ) diff --git a/manim/animation/fading.py b/manim/animation/fading.py index 8ee3d551bf..6aaba7bf5e 100644 --- a/manim/animation/fading.py +++ b/manim/animation/fading.py @@ -137,6 +137,9 @@ def construct(self): """ + def __init__(self, *mobjects: Mobject, **kwargs) -> None: + super().__init__(*mobjects, introducer=True, **kwargs) + def create_target(self): return self.mobject diff --git a/manim/animation/growing.py b/manim/animation/growing.py index 137d0c21b7..4011eb02b2 100644 --- a/manim/animation/growing.py +++ b/manim/animation/growing.py @@ -79,7 +79,7 @@ def __init__( ) -> None: self.point = point self.point_color = point_color - super().__init__(mobject, **kwargs) + super().__init__(mobject, introducer=True, **kwargs) def create_target(self) -> Mobject: return self.mobject diff --git a/manim/animation/indication.py b/manim/animation/indication.py index fc2d4de2e6..1eeb5ed8f6 100644 --- a/manim/animation/indication.py +++ b/manim/animation/indication.py @@ -300,7 +300,7 @@ def construct(self): def __init__(self, mobject: "VMobject", time_width: float = 0.1, **kwargs) -> None: self.time_width = time_width - super().__init__(mobject, remover=True, **kwargs) + super().__init__(mobject, remover=True, introducer=True, **kwargs) def _get_bounds(self, alpha: float) -> Tuple[float]: tw = self.time_width @@ -310,8 +310,8 @@ def _get_bounds(self, alpha: float) -> Tuple[float]: lower = max(lower, 0) return (lower, upper) - def finish(self) -> None: - super().finish() + def clean_up_from_scene(self, scene: "Scene") -> None: + super().clean_up_from_scene(scene) for submob, start in self.get_all_families_zipped(): submob.pointwise_become_partial(start, 0, 1) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index a9f803b86f..379a4db1de 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -7,7 +7,7 @@ ] from copy import copy -from typing import Hashable, List, Optional, Tuple, Type, Union +from typing import Hashable, Iterable, List, Optional, Tuple, Type, Union import networkx as nx import numpy as np @@ -499,7 +499,7 @@ def __getitem__(self: Graph, v: Hashable) -> Mobject: def __repr__(self: Graph) -> str: return f"Graph on {len(self.vertices)} vertices and {len(self.edges)} edges" - def _add_vertex( + def _create_vertex( self, vertex: Hashable, position: np.ndarray | None = None, @@ -508,34 +508,7 @@ def _add_vertex( vertex_type: type[Mobject] = Dot, vertex_config: dict | None = None, vertex_mobject: dict | None = None, - ) -> Mobject: - """Add a vertex to the graph. - - Parameters - ---------- - - vertex - A hashable vertex identifier. - position - The coordinates where the new vertex should be added. If ``None``, the center - of the graph is used. - label - Controls whether or not the vertex is labeled. If ``False`` (the default), - the vertex is not labeled; if ``True`` it is labeled using its - names (as specified in ``vertex``) via :class:`~.MathTex`. Alternatively, - any :class:`~.Mobject` can be passed to be used as the label. - label_fill_color - Sets the fill color of the default labels generated when ``labels`` - is set to ``True``. Has no effect for other values of ``label``. - vertex_type - The mobject class used for displaying vertices in the scene. - vertex_config - A dictionary containing keyword arguments to be passed to - the class specified via ``vertex_type``. - vertex_mobject - The mobject to be used as the vertex. Overrides all other - vertex customization options. - """ + ) -> tuple[Hashable, np.ndarray, dict, Mobject]: if position is None: position = self.get_center() @@ -547,73 +520,116 @@ def _add_vertex( f"Vertex identifier '{vertex}' is already used for a vertex in this graph.", ) - self._graph.add_node(vertex) - self._layout[vertex] = position - if isinstance(label, (Mobject, OpenGLMobject)): - self._labels[vertex] = label + label = label elif label is True: - self._labels[vertex] = MathTex(vertex, fill_color=label_fill_color) + label = MathTex(vertex, fill_color=label_fill_color) + elif vertex in self._labels: + label = self._labels[vertex] + else: + label = None base_vertex_config = copy(self.default_vertex_config) base_vertex_config.update(vertex_config) vertex_config = base_vertex_config - if vertex in self._labels: - vertex_config["label"] = self._labels[vertex] + if label is not None: + vertex_config["label"] = label if vertex_type is Dot: vertex_type = LabeledDot - self._vertex_config[vertex] = vertex_config - if vertex_mobject is None: - self.vertices[vertex] = vertex_type(**vertex_config) - else: - self.vertices[vertex] = vertex_mobject + vertex_mobject = vertex_type(**vertex_config) + + vertex_mobject.move_to(position) + return (vertex, position, vertex_config, vertex_mobject) + + def _add_created_vertex( + self, + vertex: Hashable, + position: np.ndarray, + vertex_config: dict, + vertex_mobject: Mobject, + ) -> Mobject: + if vertex in self.vertices: + raise ValueError( + f"Vertex identifier '{vertex}' is already used for a vertex in this graph.", + ) + + self._graph.add_node(vertex) + self._layout[vertex] = position + + if "label" in vertex_config: + self._labels[vertex] = vertex_config["label"] + + self._vertex_config[vertex] = vertex_config + + self.vertices[vertex] = vertex_mobject self.vertices[vertex].move_to(position) self.add(self.vertices[vertex]) return self.vertices[vertex] - def add_vertices( - self: Graph, - *vertices: Hashable, - positions: dict | None = None, - labels: bool = False, + def _add_vertex( + self, + vertex: Hashable, + position: np.ndarray | None = None, + label: bool = False, label_fill_color: str = BLACK, vertex_type: type[Mobject] = Dot, vertex_config: dict | None = None, - vertex_mobjects: dict | None = None, - ): - """Add a list of vertices to the graph. + vertex_mobject: dict | None = None, + ) -> Mobject: + """Add a vertex to the graph. Parameters ---------- - vertices - Hashable vertex identifiers. - positions - A dictionary specifying the coordinates where the new vertices should be added. - If ``None``, all vertices are created at the center of the graph. - labels + vertex + A hashable vertex identifier. + position + The coordinates where the new vertex should be added. If ``None``, the center + of the graph is used. + label Controls whether or not the vertex is labeled. If ``False`` (the default), the vertex is not labeled; if ``True`` it is labeled using its names (as specified in ``vertex``) via :class:`~.MathTex`. Alternatively, any :class:`~.Mobject` can be passed to be used as the label. label_fill_color Sets the fill color of the default labels generated when ``labels`` - is set to ``True``. Has no effect for other values of ``labels``. + is set to ``True``. Has no effect for other values of ``label``. vertex_type The mobject class used for displaying vertices in the scene. vertex_config A dictionary containing keyword arguments to be passed to the class specified via ``vertex_type``. - vertex_mobjects - A dictionary whose keys are the vertex identifiers, and whose - values are mobjects that should be used as vertices. Overrides - all other vertex customization options. + vertex_mobject + The mobject to be used as the vertex. Overrides all other + vertex customization options. """ + return self._add_created_vertex( + *self._create_vertex( + vertex=vertex, + position=position, + label=label, + label_fill_color=label_fill_color, + vertex_type=vertex_type, + vertex_config=vertex_config, + vertex_mobject=vertex_mobject, + ) + ) + + def _create_vertices( + self: Graph, + *vertices: Hashable, + positions: dict | None = None, + labels: bool = False, + label_fill_color: str = BLACK, + vertex_type: type[Mobject] = Dot, + vertex_config: dict | None = None, + vertex_mobjects: dict | None = None, + ) -> Iterable[tuple[Hashable, np.ndarray, dict, Mobject]]: if positions is None: positions = {} if vertex_mobjects is None: @@ -646,7 +662,7 @@ def add_vertices( } return [ - self._add_vertex( + self._create_vertex( v, position=positions[v], label=labels[v], @@ -658,6 +674,57 @@ def add_vertices( for v in vertices ] + def add_vertices( + self: Graph, + *vertices: Hashable, + positions: dict | None = None, + labels: bool = False, + label_fill_color: str = BLACK, + vertex_type: type[Mobject] = Dot, + vertex_config: dict | None = None, + vertex_mobjects: dict | None = None, + ): + """Add a list of vertices to the graph. + + Parameters + ---------- + + vertices + Hashable vertex identifiers. + positions + A dictionary specifying the coordinates where the new vertices should be added. + If ``None``, all vertices are created at the center of the graph. + labels + Controls whether or not the vertex is labeled. If ``False`` (the default), + the vertex is not labeled; if ``True`` it is labeled using its + names (as specified in ``vertex``) via :class:`~.MathTex`. Alternatively, + any :class:`~.Mobject` can be passed to be used as the label. + label_fill_color + Sets the fill color of the default labels generated when ``labels`` + is set to ``True``. Has no effect for other values of ``labels``. + vertex_type + The mobject class used for displaying vertices in the scene. + vertex_config + A dictionary containing keyword arguments to be passed to + the class specified via ``vertex_type``. + vertex_mobjects + A dictionary whose keys are the vertex identifiers, and whose + values are mobjects that should be used as vertices. Overrides + all other vertex customization options. + """ + return [ + self._add_created_vertex(*v) + for v in self._create_vertices( + *vertices, + positions=positions, + labels=labels, + label_fill_color=label_fill_color, + vertex_type=vertex_type, + vertex_config=vertex_config, + vertex_mobjects=vertex_mobjects, + ) + ] + @override_animate(add_vertices) def _add_vertices_animation(self, *args, anim_args=None, **kwargs): if anim_args is None: @@ -665,9 +732,17 @@ def _add_vertices_animation(self, *args, anim_args=None, **kwargs): animation = anim_args.pop("animation", Create) - vertex_mobjects = self.add_vertices(*args, **kwargs) + vertex_mobjects = self._create_vertices(*args, **kwargs) + + def on_finish(scene: Scene): + for v in vertex_mobjects: + scene.remove(v[-1]) + self._add_created_vertex(*v) + return AnimationGroup( - *(animation(v, **anim_args) for v in vertex_mobjects), group=self + *(animation(v[-1], **anim_args) for v in vertex_mobjects), + group=self, + _on_finish=on_finish, ) def _remove_vertex(self, vertex): diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 26be794b92..6273204659 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -435,7 +435,7 @@ def add(self, *mobjects): mobjects = [*mobjects, *self.foreground_mobjects] self.restructure_mobjects(to_remove=mobjects) self.mobjects += mobjects - if self.moving_mobjects: + if self.moving_mobjects is not None: self.restructure_mobjects( to_remove=mobjects, mobject_list_name="moving_mobjects", @@ -446,6 +446,8 @@ def add(self, *mobjects): def add_mobjects_from_animations(self, animations): curr_mobjects = self.get_mobject_family_members() for animation in animations: + if animation.is_introducer(): + continue # Anything animated that's not already in the # scene gets added to the scene mob = animation.mobject @@ -1022,6 +1024,7 @@ def compile_animation_data(self, *animations: Animation, **play_kwargs): def begin_animations(self) -> None: """Start the animations of the scene.""" for animation in self.animations: + animation._setup_scene(self) animation.begin() def is_current_animation_frozen_frame(self) -> bool: diff --git a/manim/utils/testing/_frames_testers.py b/manim/utils/testing/_frames_testers.py index 498403a816..a277f88097 100644 --- a/manim/utils/testing/_frames_testers.py +++ b/manim/utils/testing/_frames_testers.py @@ -26,7 +26,7 @@ def testing(self): # For backward compatibility, when the control data contains only one frame (<= v0.8.0) if len(self._frames.shape) != 4: self._frames = np.expand_dims(self._frames, axis=0) - print(self._frames.shape) + logger.debug(self._frames.shape) self._number_frames = np.ma.size(self._frames, axis=0) yield assert self._frames_compared == self._number_frames, ( diff --git a/tests/opengl/test_composition_opengl.py b/tests/opengl/test_composition_opengl.py index 8f4580580c..9cb293ed41 100644 --- a/tests/opengl/test_composition_opengl.py +++ b/tests/opengl/test_composition_opengl.py @@ -1,5 +1,7 @@ from __future__ import annotations +from unittest.mock import Mock + from manim.animation.animation import Animation, Wait from manim.animation.composition import AnimationGroup, Succession from manim.animation.fading import FadeIn, FadeOut @@ -14,6 +16,7 @@ def test_succession_timing(using_opengl_renderer): animation_4s = FadeOut(line, shift=DOWN, run_time=4.0) succession = Succession(animation_1s, animation_4s) assert succession.get_run_time() == 5.0 + succession._setup_scene(Mock()) succession.begin() assert succession.active_index == 0 # The first animation takes 20% of the total run time. @@ -45,6 +48,7 @@ def test_succession_in_succession_timing(using_opengl_renderer): ) assert nested_succession.get_run_time() == 5.0 assert succession.get_run_time() == 10.0 + succession._setup_scene(Mock()) succession.begin() succession.interpolate(0.1) assert succession.active_index == 0 diff --git a/tests/test_composition.py b/tests/test_composition.py index 844e683a89..a1a5f22f17 100644 --- a/tests/test_composition.py +++ b/tests/test_composition.py @@ -1,5 +1,7 @@ from __future__ import annotations +from unittest.mock import Mock + import pytest from manim.animation.animation import Animation, Wait @@ -18,6 +20,7 @@ def test_succession_timing(): animation_4s = FadeOut(line, shift=DOWN, run_time=4.0) succession = Succession(animation_1s, animation_4s) assert succession.get_run_time() == 5.0 + succession._setup_scene(Mock()) succession.begin() assert succession.active_index == 0 # The first animation takes 20% of the total run time. @@ -49,6 +52,7 @@ def test_succession_in_succession_timing(): ) assert nested_succession.get_run_time() == 5.0 assert succession.get_run_time() == 10.0 + succession._setup_scene(Mock()) succession.begin() succession.interpolate(0.1) assert succession.active_index == 0 @@ -106,7 +110,8 @@ def test_animationgroup_with_wait(): @pytest.mark.parametrize( - "animation_remover, animation_group_remover", [(False, True), (True, False)] + "animation_remover, animation_group_remover", + [(False, True), (True, False)], ) def test_animationgroup_is_passing_remover_to_animations( animation_remover, animation_group_remover @@ -131,7 +136,9 @@ def test_animationgroup_is_passing_remover_to_nested_animationgroups(): circ_animation = Write(Circle(), remover=True) polygon_animation = Create(RegularPolygon(5)) animation_group = AnimationGroup( - AnimationGroup(sqr_animation, polygon_animation), circ_animation, remover=True + AnimationGroup(sqr_animation, polygon_animation), + circ_animation, + remover=True, ) scene.play(animation_group) diff --git a/tests/test_graphical_units/control_data/vector_scene/vector_to_coords.npz b/tests/test_graphical_units/control_data/vector_scene/vector_to_coords.npz index d473bdb36d..d7d6eb905d 100644 Binary files a/tests/test_graphical_units/control_data/vector_scene/vector_to_coords.npz and b/tests/test_graphical_units/control_data/vector_scene/vector_to_coords.npz differ