Source code for manim.animation.transform_matching_parts

"""Animations that try to transform Mobjects while keeping track of identical parts."""

from __future__ import annotations

__all__ = ["TransformMatchingShapes", "TransformMatchingTex"]

from typing import TYPE_CHECKING

import numpy as np

from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject

from .._config import config
from ..constants import RendererType
from ..mobject.mobject import Group, Mobject
from ..mobject.types.vectorized_mobject import VGroup, VMobject
from .composition import AnimationGroup
from .fading import FadeIn, FadeOut
from .transform import FadeTransformPieces, Transform

if TYPE_CHECKING:
    from ..scene.scene import Scene


class TransformMatchingAbstractBase(AnimationGroup):
    """Abstract base class for transformations that keep track of matching parts.

    Subclasses have to implement the two static methods
    :meth:`~.TransformMatchingAbstractBase.get_mobject_parts` and
    :meth:`~.TransformMatchingAbstractBase.get_mobject_key`.

    Basically, this transformation first maps all submobjects returned
    by the ``get_mobject_parts`` method to certain keys by applying the
    ``get_mobject_key`` method. Then, submobjects with matching keys
    are transformed into each other.

    Parameters
    ----------
    mobject
        The starting :class:`~.Mobject`.
    target_mobject
        The target :class:`~.Mobject`.
    transform_mismatches
        Controls whether submobjects without a matching key are transformed
        into each other by using :class:`~.Transform`. Default: ``False``.
    fade_transform_mismatches
        Controls whether submobjects without a matching key are transformed
        into each other by using :class:`~.FadeTransformPieces`.
        Default: ``False``. Ignored when ``transform_mismatches`` is ``True``.
    key_map
        Optional. A dictionary mapping keys belonging to some of the starting mobject's
        submobjects (i.e., the return values of the ``get_mobject_key`` method)
        to some keys belonging to the target mobject's submobjects that should
        be transformed although the keys don't match.
    kwargs
        All further keyword arguments are passed to the submobject transformations.


    Note
    ----
    If neither ``transform_mismatches`` nor ``fade_transform_mismatches``
    are set to ``True``, submobjects without matching keys in the starting
    mobject are faded out in the direction of the unmatched submobjects in
    the target mobject, and unmatched submobjects in the target mobject
    are faded in from the direction of the unmatched submobjects in the
    start mobject.

    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        group_type = self._get_group_type(mobject)

        source_map = self.get_shape_map(mobject)
        target_map = self.get_shape_map(target_mobject)

        if key_map is None:
            key_map = {}

        # Create two mobjects whose submobjects all match each other
        # according to whatever keys are used for source_map and
        # target_map
        transform_source = group_type()
        transform_target = group_type()
        # From here on, every sub-animation built below receives this value.
        kwargs["final_alpha_value"] = 0
        for key in set(source_map).intersection(target_map):
            transform_source.add(source_map[key])
            transform_target.add(target_map[key])
        anims = [Transform(transform_source, transform_target, **kwargs)]

        # User can manually specify when one part should transform
        # into another despite not matching by using key_map
        key_mapped_source = group_type()
        key_mapped_target = group_type()
        for key1, key2 in key_map.items():
            if key1 in source_map and key2 in target_map:
                key_mapped_source.add(source_map[key1])
                key_mapped_target.add(target_map[key2])
                # Drop the manually mapped parts so they are not also
                # treated as mismatches below.
                source_map.pop(key1, None)
                target_map.pop(key2, None)
        if len(key_mapped_source) > 0:
            anims.append(
                FadeTransformPieces(key_mapped_source, key_mapped_target, **kwargs),
            )

        # Whatever is left without a partner on the other side.
        fade_source = group_type()
        fade_target = group_type()
        for key in set(source_map).difference(target_map):
            fade_source.add(source_map[key])
        for key in set(target_map).difference(source_map):
            fade_target.add(target_map[key])
        fade_target_copy = fade_target.copy()

        if transform_mismatches:
            # Only the mismatch Transform sees this default; the animations
            # created above were built before it was added to kwargs.
            kwargs.setdefault("replace_mobject_with_target_in_scene", True)
            anims.append(Transform(fade_source, fade_target, **kwargs))
        elif fade_transform_mismatches:
            anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
        else:
            anims.append(FadeOut(fade_source, target_position=fade_target, **kwargs))
            anims.append(
                FadeIn(fade_target_copy, target_position=fade_target, **kwargs),
            )

        super().__init__(*anims)

        # Consumed by clean_up_from_scene once the animation has finished.
        self.to_remove = [mobject, fade_target_copy]
        self.to_add = target_mobject

    @staticmethod
    def _get_group_type(mobject: Mobject) -> type:
        """Return the group class matching the concrete type of ``mobject``.

        OpenGL mobjects get the OpenGL group classes and vectorized mobjects
        get the vectorized group classes; anything else falls back to
        :class:`~.Group`.
        """
        if isinstance(mobject, OpenGLVMobject):
            return OpenGLVGroup
        if isinstance(mobject, OpenGLMobject):
            return OpenGLGroup
        if isinstance(mobject, VMobject):
            return VGroup
        return Group

    def get_shape_map(self, mobject: Mobject) -> dict:
        """Group the parts of ``mobject`` by their key.

        Returns
        -------
        dict
            Maps every key produced by ``get_mobject_key`` to a group holding
            all parts (as returned by ``get_mobject_parts``) with that key.
        """
        shape_map = {}
        for part in self.get_mobject_parts(mobject):
            key = self.get_mobject_key(part)
            if key not in shape_map:
                # Choose the container from the part itself (the same rule
                # __init__ applies to the whole mobject) instead of from the
                # global renderer setting, so that a non-vectorized part is
                # not placed in a vectorized group.
                shape_map[key] = self._get_group_type(part)()
            shape_map[key].add(part)
        return shape_map

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Swap the start mobject for the target mobject in ``scene``.

        All sub-animations are reset to alpha 0 first, then the mobjects in
        ``self.to_remove`` are removed and ``self.to_add`` is added.
        """
        # Interpolate all animations back to 0 to ensure source mobjects remain unchanged.
        for anim in self.animations:
            anim.interpolate(0)
        scene.remove(self.mobject)
        scene.remove(*self.to_remove)
        scene.add(self.to_add)

    @staticmethod
    def get_mobject_parts(mobject: Mobject):
        """Return the parts of ``mobject`` to be matched (subclass hook)."""
        raise NotImplementedError("To be implemented in subclass.")

    @staticmethod
    def get_mobject_key(mobject: Mobject):
        """Return the key used to match a single part (subclass hook)."""
        raise NotImplementedError("To be implemented in subclass.")
class TransformMatchingShapes(TransformMatchingAbstractBase):
    """An animation trying to transform groups by matching the shape
    of their submobjects.

    Two submobjects match if the hash of their point coordinates after
    normalization (i.e., after translation to the origin, fixing the submobject
    height at 1 unit, and rounding the coordinates to three decimal places)
    matches.

    See also
    --------
    :class:`~.TransformMatchingAbstractBase`

    Examples
    --------

    .. manim:: Anagram

        class Anagram(Scene):
            def construct(self):
                src = Text("the morse code")
                tar = Text("here come dots")
                self.play(Write(src))
                self.wait(0.5)
                self.play(TransformMatchingShapes(src, tar, path_arc=PI/2))
                self.wait(0.5)

    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            mobject,
            target_mobject,
            transform_mismatches=transform_mismatches,
            fade_transform_mismatches=fade_transform_mismatches,
            key_map=key_map,
            **kwargs,
        )

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
        """Return all family members of ``mobject`` that have points."""
        return mobject.family_members_with_points()

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> int:
        """Return a hash of the normalized, rounded points of ``mobject``.

        The mobject is temporarily centered and set to height 1; its state is
        saved beforehand and restored afterwards, even if hashing fails.
        """
        mobject.save_state()
        try:
            mobject.center()
            mobject.set_height(1)
            # Rounding tiny negative coordinates yields -0.0, whose byte
            # pattern differs from +0.0. Adding 0.0 maps every -0.0 to +0.0
            # so that identical shapes always produce identical bytes.
            rounded_points = np.round(mobject.points, 3) + 0.0
            return hash(rounded_points.tobytes())
        finally:
            mobject.restore()
class TransformMatchingTex(TransformMatchingAbstractBase):
    """A transformation trying to transform rendered LaTeX strings.

    Two submobjects match if their ``tex_string`` matches.

    See also
    --------
    :class:`~.TransformMatchingAbstractBase`

    Examples
    --------

    .. manim:: MatchingEquationParts

        class MatchingEquationParts(Scene):
            def construct(self):
                variables = VGroup(MathTex("a"), MathTex("b"), MathTex("c")).arrange_submobjects().shift(UP)

                eq1 = MathTex("{{x}}^2", "+", "{{y}}^2", "=", "{{z}}^2")
                eq2 = MathTex("{{a}}^2", "+", "{{b}}^2", "=", "{{c}}^2")
                eq3 = MathTex("{{a}}^2", "=", "{{c}}^2", "-", "{{b}}^2")

                self.add(eq1)
                self.wait(0.5)
                self.play(TransformMatchingTex(Group(eq1, variables), eq2))
                self.wait(0.5)
                self.play(TransformMatchingTex(eq2, eq3))
                self.wait(0.5)

    """

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs,
    ):
        super().__init__(
            mobject,
            target_mobject,
            transform_mismatches=transform_mismatches,
            fade_transform_mismatches=fade_transform_mismatches,
            key_map=key_map,
            **kwargs,
        )

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
        """Return the tex parts of ``mobject``.

        Plain groups are flattened recursively; any other mobject must carry
        a ``tex_string`` attribute, in which case its submobjects are the
        parts.

        Raises
        ------
        AssertionError
            If a non-group mobject has no ``tex_string`` attribute.
        """
        if isinstance(mobject, (Group, VGroup, OpenGLGroup, OpenGLVGroup)):
            return [
                p
                for s in mobject.submobjects
                for p in TransformMatchingTex.get_mobject_parts(s)
            ]
        if not hasattr(mobject, "tex_string"):
            # Raised explicitly (instead of via a bare ``assert``) so the
            # check also runs under ``python -O``; the exception type is
            # kept as AssertionError for backward compatibility.
            raise AssertionError(
                "TransformMatchingTex can only match mobjects that have a "
                f"'tex_string' attribute, but got {type(mobject).__name__}."
            )
        return mobject.submobjects

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> str:
        """Return the ``tex_string`` of ``mobject`` as its matching key."""
        return mobject.tex_string