Skip to content

Commit b3c079d

Browse files
leotrsPgBiel
andauthored
Refactor batch_by_property (#326)
* Remove batch_by_property in favor of itertools.groupby. Make display_funcs an instance variable rather than a variable local to capture_mobjects. Turn get_mobject_type into a method instead of a function local to capture_mobjects. I also fixed a typo in a method name. * typo * thanks black * some doc fixes * Update manim/camera/camera.py Co-authored-by: Pg Biel <[email protected]> * Update manim/camera/camera.py Co-authored-by: Pg Biel <[email protected]> * Update manim/camera/camera.py Co-authored-by: Pg Biel <[email protected]> * Apply suggestions from code review Co-authored-by: Pg Biel <[email protected]> * Update manim/camera/camera.py Co-authored-by: Pg Biel <[email protected]> Co-authored-by: Pg Biel <[email protected]>
1 parent d49de94 commit b3c079d

File tree

2 files changed

+81
-61
lines changed

2 files changed

+81
-61
lines changed

manim/camera/camera.py

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from ..utils.color import color_to_int_rgba
2020
from ..utils.config_ops import digest_config
2121
from ..utils.images import get_full_raster_image_path
22-
from ..utils.iterables import batch_by_property
2322
from ..utils.iterables import list_difference_update
2423
from ..utils.iterables import remove_list_redundancies
2524
from ..utils.simple_functions import fdiv
@@ -80,6 +79,18 @@ def __init__(self, background=None, **kwargs):
8079
digest_config(self, kwargs, locals())
8180
self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max
8281
self.pixel_array_to_cairo_context = {}
82+
83+
# Contains the correct method to process a list of Mobjects of the
84+
# corresponding class. If a Mobject is not an instance of a class in
85+
# this dict (or an instance of a class that inherits from a class in
86+
# this dict), then it cannot be rendered.
87+
self.display_funcs = {
88+
VMobject: self.display_multiple_vectorized_mobjects,
89+
PMobject: self.display_multiple_point_cloud_mobjects,
90+
AbstractImageMobject: self.display_multiple_image_mobjects,
91+
Mobject: lambda batch, pa: batch, # Do nothing
92+
}
93+
8394
self.init_background()
8495
self.resize_frame_shape()
8596
self.reset()
@@ -91,6 +102,42 @@ def __deepcopy__(self, memo):
91102
self.canvas = None
92103
return copy.copy(self)
93104

105+
def type_or_raise(self, mobject):
106+
"""Return the type of mobject, if it is a type that can be rendered.
107+
108+
If `mobject` is an instance of a class that inherits from a class that
109+
can be rendered, return the super class. For example, an instance of a
110+
Square is also an instance of VMobject, and these can be rendered.
111+
Therefore, `type_or_raise(Square())` returns True.
112+
113+
Parameters
114+
----------
115+
mobject : :class:`~.Mobject`
116+
The object to take the type of.
117+
118+
Notes
119+
-----
120+
For a list of classes that can currently be rendered, see :meth:`display_funcs`.
121+
122+
Returns
123+
-------
124+
Type[:class:`~.Mobject`]
125+
The type of mobjects, if it can be rendered.
126+
127+
Raises
128+
------
129+
:exc:`TypeError`
130+
When mobject is not an instance of a class that can be rendered.
131+
"""
132+
# We have to check each type in turn because we are dealing with
133+
# super classes. For example, if square = Square(), then
134+
# type(square) != VMobject, but isinstance(square, VMobject) == True.
135+
for _type in self.display_funcs:
136+
if isinstance(mobject, _type):
137+
return _type
138+
else:
139+
raise TypeError(f"Displaying an object of class {_type} is not supported")
140+
94141
def reset_pixel_shape(self, new_height, new_width):
95142
"""This method resets the height and width
96143
of a single pixel to the passed new_heigh and new_width.
@@ -381,34 +428,35 @@ def capture_mobject(
381428
): # TODO Write better docstrings for this method.
382429
return self.capture_mobjects([mobject], **kwargs)
383430

384-
def capture_mobjects(
385-
self, mobjects, **kwargs
386-
): # TODO Write better docstrings for this method.
387-
mobjects = self.get_mobjects_to_display(mobjects, **kwargs)
431+
def capture_mobjects(self, mobjects, **kwargs):
432+
"""Capture mobjects by printing them on :attr:`pixel_array`.
433+
434+
This is the essential function that converts the contents of a Scene
435+
into an array, which is then converted to an image or video.
436+
437+
Parameters
438+
----------
439+
mobjects : :class:`~.Mobject`
440+
Mobjects to capture.
441+
442+
kwargs : Any
443+
Keyword arguments to be passed to :meth:`get_mobjects_to_display`.
388444
389-
# Organize this list into batches of the same type, and
390-
# apply corresponding function to those batches
391-
type_func_pairs = [
392-
(VMobject, self.display_multiple_vectorized_mobjects),
393-
(PMobject, self.display_multiple_point_cloud_mobjects),
394-
(AbstractImageMobject, self.display_multiple_image_mobjects),
395-
(Mobject, lambda batch, pa: batch), # Do nothing
396-
]
397-
398-
def get_mobject_type(mobject):
399-
for mobject_type, func in type_func_pairs:
400-
if isinstance(mobject, mobject_type):
401-
return mobject_type
402-
raise Exception("Trying to display something which is not of type Mobject")
403-
404-
batch_type_pairs = batch_by_property(mobjects, get_mobject_type)
405-
406-
# Display in these batches
407-
for batch, batch_type in batch_type_pairs:
408-
# check what the type is, and call the appropriate function
409-
for mobject_type, func in type_func_pairs:
410-
if batch_type == mobject_type:
411-
func(batch, self.pixel_array)
445+
Notes
446+
-----
447+
For a list of classes that can currently be rendered, see :meth:`display_funcs`.
448+
449+
"""
450+
# The mobjects will be processed in batches (or runs) of mobjects of
451+
# the same type. That is, if the list mobjects contains objects of
452+
# types [VMobject, VMobject, VMobject, PMobject, PMobject, VMobject],
453+
# then they will be captured in three batches: [VMobject, VMobject,
454+
# VMobject], [PMobject, PMobject], and [VMobject]. This must be done
455+
# without altering their order. it.groupby computes exactly this
456+
# partition while at the same time preserving order.
457+
mobjects = self.get_mobjects_to_display(mobjects, **kwargs)
458+
for group_type, group in it.groupby(mobjects, self.type_or_raise):
459+
self.display_funcs[group_type](list(group), self.pixel_array)
412460

413461
# Methods associated with svg rendering
414462

@@ -497,12 +545,12 @@ def display_multiple_vectorized_mobjects(self, vmobjects, pixel_array):
497545
"""
498546
if len(vmobjects) == 0:
499547
return
500-
batch_file_pairs = batch_by_property(
548+
batch_file_pairs = it.groupby(
501549
vmobjects, lambda vm: vm.get_background_image_file()
502550
)
503-
for batch, file_name in batch_file_pairs:
551+
for file_name, batch in batch_file_pairs:
504552
if file_name:
505-
self.display_multiple_background_colored_vmobject(batch, pixel_array)
553+
self.display_multiple_background_colored_vmobjects(batch, pixel_array)
506554
else:
507555
self.display_multiple_non_background_colored_vmobjects(
508556
batch, pixel_array
@@ -714,7 +762,7 @@ def get_background_colored_vmobject_displayer(self):
714762
setattr(self, bcvd, BackgroundColoredVMobjectDisplayer(self))
715763
return getattr(self, bcvd)
716764

717-
def display_multiple_background_colored_vmobject(self, cvmobjects, pixel_array):
765+
def display_multiple_background_colored_vmobjects(self, cvmobjects, pixel_array):
718766
"""Displays multiple vmobjects that have the same color as the background.
719767
720768
Parameters
@@ -1173,7 +1221,7 @@ def display(self, *cvmobjects):
11731221
np.array
11741222
The pixel array with the `cvmobjects` displayed.
11751223
"""
1176-
batch_image_file_pairs = batch_by_property(
1224+
batch_image_file_pairs = it.groupby(
11771225
cvmobjects, lambda cv: cv.get_background_image_file()
11781226
)
11791227
curr_array = None

manim/utils/iterables.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,6 @@ def adjacent_pairs(objects):
4141
return adjacent_n_tuples(objects, 2)
4242

4343

44-
def batch_by_property(items, property_func):
45-
"""
46-
Takes in a list, and returns a list of tuples, (batch, prop)
47-
such that all items in a batch have the same output when
48-
put into property_func, and such that chaining all these
49-
batches together would give the original list (i.e. order is
50-
preserved)
51-
"""
52-
batch_prop_pairs = []
53-
54-
def add_batch_prop_pair(batch):
55-
if len(batch) > 0:
56-
batch_prop_pairs.append((batch, property_func(batch[0])))
57-
58-
curr_batch = []
59-
curr_prop = None
60-
for item in items:
61-
prop = property_func(item)
62-
if prop != curr_prop:
63-
add_batch_prop_pair(curr_batch)
64-
curr_prop = prop
65-
curr_batch = [item]
66-
else:
67-
curr_batch.append(item)
68-
add_batch_prop_pair(curr_batch)
69-
return batch_prop_pairs
70-
71-
7244
def tuplify(obj):
7345
if isinstance(obj, str):
7446
return (obj,)

0 commit comments

Comments
 (0)