# Source code for manim.animation.transform_matching_parts
"""Animations that try to transform Mobjects while keeping track of identical parts."""
from __future__ import annotations
__all__ = ["TransformMatchingShapes", "TransformMatchingTex"]
from typing import TYPE_CHECKING
import numpy as np
from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject
from .._config import config
from ..constants import RendererType
from ..mobject.mobject import Group, Mobject
from ..mobject.types.vectorized_mobject import VGroup, VMobject
from .composition import AnimationGroup
from .fading import FadeIn, FadeOut
from .transform import FadeTransformPieces, Transform
if TYPE_CHECKING:
from ..scene.scene import Scene
# [docs]
class TransformMatchingAbstractBase(AnimationGroup):
    """Abstract base class for transformations that keep track of matching parts.

    Subclasses have to implement the two static methods
    :meth:`~.TransformMatchingAbstractBase.get_mobject_parts` and
    :meth:`~.TransformMatchingAbstractBase.get_mobject_key`.

    Basically, this transformation first maps all submobjects returned
    by the ``get_mobject_parts`` method to certain keys by applying the
    ``get_mobject_key`` method. Then, submobjects with matching keys
    are transformed into each other.

    Parameters
    ----------
    mobject
        The starting :class:`~.Mobject`.
    target_mobject
        The target :class:`~.Mobject`.
    transform_mismatches
        Controls whether submobjects without a matching key are transformed
        into each other by using :class:`~.Transform`. Default: ``False``.
    fade_transform_mismatches
        Controls whether submobjects without a matching key are transformed
        into each other by using :class:`~.FadeTransformPieces`. Default: ``False``.
    key_map
        Optional. A dictionary mapping keys belonging to some of the starting mobject's
        submobjects (i.e., the return values of the ``get_mobject_key`` method)
        to some keys belonging to the target mobject's submobjects that should
        be transformed although the keys don't match. Pairs listed here take
        precedence over regular key matching: the parts they claim are not
        additionally transformed or faded by any other sub-animation.
    kwargs
        All further keyword arguments are passed to the submobject transformations.

    Note
    ----
    If neither ``transform_mismatches`` nor ``fade_transform_mismatches``
    are set to ``True``, submobjects without matching keys in the starting
    mobject are faded out in the direction of the unmatched submobjects in
    the target mobject, and unmatched submobjects in the target mobject
    are faded in from the direction of the unmatched submobjects in the
    start mobject.
    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        # Choose the container class matching the kind of the source mobject,
        # so the helper groups built below can hold its parts.
        if isinstance(mobject, OpenGLVMobject):
            group_type = OpenGLVGroup
        elif isinstance(mobject, OpenGLMobject):
            group_type = OpenGLGroup
        elif isinstance(mobject, VMobject):
            group_type = VGroup
        else:
            group_type = Group

        source_map = self.get_shape_map(mobject)
        target_map = self.get_shape_map(target_mobject)

        if key_map is None:
            key_map = {}

        # NOTE(review): forwarded to every sub-animation built below; presumably
        # makes them finish at alpha 0. ``clean_up_from_scene`` interpolates all
        # sub-animations back to 0 in any case.
        kwargs["final_alpha_value"] = 0

        # User can manually specify when one part should transform into
        # another despite not matching, by using key_map. These pairs are
        # claimed *before* the exact matching below and removed from both
        # maps. Otherwise a key that also matches exactly would be driven by
        # two animations at once (Transform and FadeTransformPieces), and its
        # original target part would additionally end up in the fade set.
        key_mapped_source = group_type()
        key_mapped_target = group_type()
        for source_key, target_key in key_map.items():
            if source_key in source_map and target_key in target_map:
                key_mapped_source.add(source_map.pop(source_key))
                key_mapped_target.add(target_map.pop(target_key))

        # Create two mobjects whose submobjects all match each other
        # according to whatever keys are used for source_map and
        # target_map. Both groups are filled in the same key order, so the
        # n-th source part pairs with the n-th target part.
        transform_source = group_type()
        transform_target = group_type()
        for key in set(source_map).intersection(target_map):
            transform_source.add(source_map[key])
            transform_target.add(target_map[key])
        anims = [Transform(transform_source, transform_target, **kwargs)]

        if len(key_mapped_source) > 0:
            anims.append(
                FadeTransformPieces(key_mapped_source, key_mapped_target, **kwargs),
            )

        # Whatever is left unmatched on either side is handled according to
        # the mismatch flags.
        fade_source = group_type()
        fade_target = group_type()
        for key in set(source_map).difference(target_map):
            fade_source.add(source_map[key])
        for key in set(target_map).difference(source_map):
            fade_target.add(target_map[key])
        # A copy is faded in, so it can be removed again during clean-up
        # (see ``self.to_remove``) and replaced by the real target mobject.
        fade_target_copy = fade_target.copy()

        if transform_mismatches:
            if "replace_mobject_with_target_in_scene" not in kwargs:
                kwargs["replace_mobject_with_target_in_scene"] = True
            anims.append(Transform(fade_source, fade_target, **kwargs))
        elif fade_transform_mismatches:
            anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
        else:
            anims.append(FadeOut(fade_source, target_position=fade_target, **kwargs))
            anims.append(
                FadeIn(fade_target_copy, target_position=fade_target, **kwargs),
            )

        super().__init__(*anims)

        # Used by ``clean_up_from_scene`` to swap the source for the target.
        self.to_remove = [mobject, fade_target_copy]
        self.to_add = target_mobject

    def get_shape_map(self, mobject: Mobject) -> dict:
        """Group the parts of ``mobject`` by their key.

        Returns a dictionary mapping each key produced by ``get_mobject_key``
        to a vectorized group (``OpenGLVGroup`` under the OpenGL renderer,
        ``VGroup`` otherwise) holding every part with that key, in the order
        returned by ``get_mobject_parts``.
        """
        shape_map = {}
        for sm in self.get_mobject_parts(mobject):
            key = self.get_mobject_key(sm)
            if key not in shape_map:
                if config["renderer"] == RendererType.OPENGL:
                    shape_map[key] = OpenGLVGroup()
                else:
                    shape_map[key] = VGroup()
            shape_map[key].add(sm)
        return shape_map

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Reset the sub-animations and replace the source by the target in ``scene``."""
        # Interpolate all animations back to 0 to ensure source mobjects remain unchanged.
        for anim in self.animations:
            anim.interpolate(0)
        scene.remove(self.mobject)
        scene.remove(*self.to_remove)
        scene.add(self.to_add)

    @staticmethod
    def get_mobject_parts(mobject: Mobject):
        """Return the parts of ``mobject`` that should be matched. Subclass hook."""
        raise NotImplementedError("To be implemented in subclass.")

    @staticmethod
    def get_mobject_key(mobject: Mobject):
        """Return the hashable key used to match a single part. Subclass hook."""
        raise NotImplementedError("To be implemented in subclass.")
# [docs]
class TransformMatchingShapes(TransformMatchingAbstractBase):
    """An animation trying to transform groups by matching the shape
    of their submobjects.

    Two submobjects match if the hash of their point coordinates after
    normalization (i.e., after translation to the origin, fixing the submobject
    height at 1 unit, and rounding the coordinates to three decimal places)
    matches.

    See also
    --------
    :class:`~.TransformMatchingAbstractBase`

    Examples
    --------

    .. manim:: Anagram

        class Anagram(Scene):
            def construct(self):
                src = Text("the morse code")
                tar = Text("here come dots")
                self.play(Write(src))
                self.wait(0.5)
                self.play(TransformMatchingShapes(src, tar, path_arc=PI/2))
                self.wait(0.5)

    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            mobject,
            target_mobject,
            transform_mismatches=transform_mismatches,
            fade_transform_mismatches=fade_transform_mismatches,
            key_map=key_map,
            **kwargs,
        )

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
        """Return every member of ``mobject``'s family that has points."""
        return mobject.family_members_with_points()

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> int:
        """Return a hash of the part's normalized point coordinates.

        The part is temporarily centered and scaled to height 1, and its
        points are rounded to three decimals before hashing. The part is
        restored afterwards. Any ``saved_state`` it carried before the call
        is preserved, because ``save_state`` would otherwise overwrite it.
        """
        # save_state() replaces ``saved_state``; remember what the user had so
        # that e.g. a later ``Restore`` animation on this part still works.
        had_saved_state = hasattr(mobject, "saved_state")
        previous_saved_state = getattr(mobject, "saved_state", None)
        mobject.save_state()
        try:
            mobject.center()
            mobject.set(height=1)
            result = hash(np.round(mobject.points, 3).tobytes())
        finally:
            # Undo the normalization even if it failed half-way.
            mobject.restore()
            if had_saved_state:
                mobject.saved_state = previous_saved_state
            else:
                del mobject.saved_state
        return result
# [docs]
class TransformMatchingTex(TransformMatchingAbstractBase):
    """A transformation trying to transform rendered LaTeX strings.

    Two submobjects match if their ``tex_string`` matches.

    See also
    --------
    :class:`~.TransformMatchingAbstractBase`

    Examples
    --------

    .. manim:: MatchingEquationParts

        class MatchingEquationParts(Scene):
            def construct(self):
                variables = VGroup(MathTex("a"), MathTex("b"), MathTex("c")).arrange_submobjects().shift(UP)

                eq1 = MathTex("{{x}}^2", "+", "{{y}}^2", "=", "{{z}}^2")
                eq2 = MathTex("{{a}}^2", "+", "{{b}}^2", "=", "{{c}}^2")
                eq3 = MathTex("{{a}}^2", "=", "{{c}}^2", "-", "{{b}}^2")

                self.add(eq1)
                self.wait(0.5)
                self.play(TransformMatchingTex(Group(eq1, variables), eq2))
                self.wait(0.5)
                self.play(TransformMatchingTex(eq2, eq3))
                self.wait(0.5)

    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            mobject,
            target_mobject,
            transform_mismatches=transform_mismatches,
            fade_transform_mismatches=fade_transform_mismatches,
            key_map=key_map,
            **kwargs,
        )

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
        """Return the submobjects of every tex mobject contained in ``mobject``.

        Groups (``Group``, ``VGroup`` and their OpenGL counterparts) are
        flattened recursively. Any other mobject must carry a ``tex_string``
        attribute; its direct submobjects are returned as the parts.

        Raises
        ------
        TypeError
            If a non-group mobject without ``tex_string`` is encountered.
        """
        if isinstance(mobject, (Group, VGroup, OpenGLGroup, OpenGLVGroup)):
            return [
                part
                for submob in mobject.submobjects
                for part in TransformMatchingTex.get_mobject_parts(submob)
            ]
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, and this gives the user an actionable message.
        if not hasattr(mobject, "tex_string"):
            raise TypeError(
                "TransformMatchingTex can only match parts of mobjects with a "
                "'tex_string' attribute (or groups of such mobjects), got "
                f"{type(mobject).__name__}."
            )
        return mobject.submobjects

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> str:
        """Use the part's ``tex_string`` as its matching key."""
        return mobject.tex_string