# Source code for manim.mobject.probability

"""Mobjects representing objects from probability theory and statistics."""

# Public names exported by ``from manim.mobject.probability import *``.
__all__ = ["SampleSpace", "BarChart"]


from typing import Iterable, List

import numpy as np

from ..constants import *
from ..mobject.geometry import Line, Rectangle
from ..mobject.mobject import Mobject
from ..mobject.opengl_mobject import OpenGLMobject
from ..mobject.svg.brace import Brace
from ..mobject.svg.tex_mobject import MathTex, Tex
from ..mobject.types.vectorized_mobject import VGroup
from ..utils.color import (
    BLUE,
    BLUE_E,
    DARK_GREY,
    GREEN_E,
    LIGHT_GREY,
    MAROON_B,
    YELLOW,
    color_gradient,
)
from ..utils.iterables import tuplify

# Tolerance used by ``SampleSpace.complete_p_list``: probability lists whose
# sum is within this distance of 1 are treated as already complete.
EPSILON = 1e-4


class SampleSpace(Rectangle):
    """A rectangle representing a sample space that can be divided into parts.

    The rectangle can be split horizontally or vertically into sub-rectangles
    whose extents are proportional to a list of probabilities. Those parts can
    then be annotated with braces and labels.

    Parameters
    ----------
    height
        Height of the rectangle.
    width
        Width of the rectangle.
    fill_color
        Fill color of the rectangle.
    fill_opacity
        Fill opacity of the rectangle.
    stroke_width
        Width of the rectangle's outline.
    stroke_color
        Color of the rectangle's outline.
    default_label_scale_val
        Scale factor applied to labels that
        :meth:`get_subdivision_braces_and_labels` builds from strings.

    Examples
    --------
    .. manim:: ExampleSampleSpace
        :save_last_frame:

        class ExampleSampleSpace(Scene):
            def construct(self):
                poly1 = SampleSpace(stroke_width=15, fill_opacity=1)
                poly2 = SampleSpace(width=5, height=3, stroke_width=5, fill_opacity=0.5)
                poly3 = SampleSpace(width=2, height=2, stroke_width=5, fill_opacity=0.1)
                poly3.divide_vertically(p_list=np.array([0.37, 0.13, 0.5]), colors=[BLACK, WHITE, GRAY], vect=RIGHT)
                poly_group = VGroup(poly1, poly2, poly3).arrange()
                self.add(poly_group)
    """

    def __init__(
        self,
        height=3,
        width=3,
        fill_color=DARK_GREY,
        fill_opacity=1,
        stroke_width=0.5,
        stroke_color=LIGHT_GREY,
        default_label_scale_val=1,
    ):
        super().__init__(
            height=height,
            width=width,
            fill_color=fill_color,
            fill_opacity=fill_opacity,
            stroke_width=stroke_width,
            stroke_color=stroke_color,
        )
        self.default_label_scale_val = default_label_scale_val

    def add_title(self, title="Sample space", buff=MED_SMALL_BUFF):
        """Add a :class:`Tex` title above the sample space.

        The title is stored as ``self.title`` and added as a submobject.

        Parameters
        ----------
        title
            Text of the title.
        buff
            Gap between the sample space and the title.
        """
        # TODO, should this really exist in SampleSpaceScene
        title_mob = Tex(title)
        # Keep the title no wider than the sample space itself.
        if title_mob.width > self.width:
            title_mob.width = self.width
        title_mob.next_to(self, UP, buff=buff)
        self.title = title_mob
        self.add(title_mob)

    def add_label(self, label):
        """Store ``label`` as ``self.label``.

        The label is only recorded; it is not added as a submobject.
        """
        self.label = label

    def complete_p_list(self, p_list):
        """Return ``p_list`` as a list whose entries sum to 1.

        ``p_list`` is first passed through ``tuplify`` (presumably so that a
        single number is accepted too — confirm against ``tuplify``). If the
        entries do not already sum to 1 within ``EPSILON``, the remainder
        ``1 - sum(p_list)`` is appended as one extra entry.

        Note that if the entries sum to more than ``1 + EPSILON``, the appended
        remainder is negative.
        """
        new_p_list = list(tuplify(p_list))
        remainder = 1.0 - sum(new_p_list)
        if abs(remainder) > EPSILON:
            new_p_list.append(remainder)
        return new_p_list

    def get_division_along_dimension(self, p_list, dim, colors, vect):
        """Split the sample space into consecutive parts along one dimension.

        Parameters
        ----------
        p_list
            Relative sizes of the parts; completed with :meth:`complete_p_list`.
        dim
            Index of the dimension to stretch. :meth:`get_vertical_division`
            passes ``0`` and :meth:`get_horizontal_division` passes ``1``.
        colors
            Reference colors, interpolated with ``color_gradient`` to one color
            per part.
        vect
            Direction in which successive parts are laid out. The first part
            starts at the edge of the sample space facing ``-vect``.

        Returns
        -------
        VGroup
            One :class:`SampleSpace` per entry of the completed ``p_list``.
            The parts are not added to ``self``.
        """
        p_list = self.complete_p_list(p_list)
        colors = color_gradient(colors, len(p_list))
        last_point = self.get_edge_center(-vect)
        parts = VGroup()
        for factor, color in zip(p_list, colors):
            part = SampleSpace()
            part.set_fill(color, 1)
            part.replace(self, stretch=True)
            part.stretch(factor, dim)
            # Place this part flush against the far edge of the previous one.
            part.move_to(last_point, -vect)
            last_point = part.get_edge_center(vect)
            parts.add(part)
        return parts

    def get_horizontal_division(self, p_list, colors=(GREEN_E, BLUE_E), vect=DOWN):
        """Return parts of the sample space stacked along ``vect`` (dimension 1).

        ``colors`` defaults to an immutable tuple so the default is never shared
        mutable state between calls.
        """
        return self.get_division_along_dimension(p_list, 1, colors, vect)

    def get_vertical_division(self, p_list, colors=(MAROON_B, YELLOW), vect=RIGHT):
        """Return parts of the sample space laid out along ``vect`` (dimension 0).

        ``colors`` defaults to an immutable tuple so the default is never shared
        mutable state between calls.
        """
        return self.get_division_along_dimension(p_list, 0, colors, vect)

    def divide_horizontally(self, *args, **kwargs):
        """Divide with :meth:`get_horizontal_division` and add the parts.

        The parts are stored as ``self.horizontal_parts``.
        """
        self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
        self.add(self.horizontal_parts)

    def divide_vertically(self, *args, **kwargs):
        """Divide with :meth:`get_vertical_division` and add the parts.

        The parts are stored as ``self.vertical_parts``.
        """
        self.vertical_parts = self.get_vertical_division(*args, **kwargs)
        self.add(self.vertical_parts)

    def get_subdivision_braces_and_labels(
        self,
        parts,
        labels,
        direction,
        buff=SMALL_BUFF,
        min_num_quads=1,
    ):
        """Build one brace and one label for each part.

        Labels that are already mobjects are used as-is. Any other label is
        rendered with :class:`MathTex` and scaled by
        ``self.default_label_scale_val``. Pairs are formed with ``zip``, so
        surplus parts or labels are ignored.

        The braces, the labels and ``label_kwargs`` are also attached to
        ``parts`` as attributes. :meth:`add_braces_and_labels` reads the
        ``braces`` and ``labels`` attributes.

        Returns
        -------
        VGroup
            A group containing the braces group and the labels group.
        """
        label_mobs = VGroup()
        braces = VGroup()
        for label, part in zip(labels, parts):
            brace = Brace(part, direction, min_num_quads=min_num_quads, buff=buff)
            if isinstance(label, (Mobject, OpenGLMobject)):
                label_mob = label
            else:
                label_mob = MathTex(label)
                label_mob.scale(self.default_label_scale_val)
            label_mob.next_to(brace, direction, buff)

            braces.add(brace)
            label_mobs.add(label_mob)
        parts.braces = braces
        parts.labels = label_mobs
        parts.label_kwargs = {
            "labels": label_mobs.copy(),
            "direction": direction,
            "buff": buff,
        }
        return VGroup(parts.braces, parts.labels)

    def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs):
        """Braces and labels beside the horizontal parts.

        Requires :meth:`divide_horizontally` to have been called first.
        """
        assert hasattr(self, "horizontal_parts")
        parts = self.horizontal_parts
        return self.get_subdivision_braces_and_labels(
            parts, labels, direction, **kwargs
        )

    def get_top_braces_and_labels(self, labels, **kwargs):
        """Braces and labels above the vertical parts.

        Requires :meth:`divide_vertically` to have been called first.
        """
        assert hasattr(self, "vertical_parts")
        parts = self.vertical_parts
        return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)

    def get_bottom_braces_and_labels(self, labels, **kwargs):
        """Braces and labels below the vertical parts.

        Requires :meth:`divide_vertically` to have been called first.
        """
        assert hasattr(self, "vertical_parts")
        parts = self.vertical_parts
        return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)

    def add_braces_and_labels(self):
        """Add every brace/label group previously attached to the division parts."""
        for attr in "horizontal_parts", "vertical_parts":
            if not hasattr(self, attr):
                continue
            parts = getattr(self, attr)
            for subattr in "braces", "labels":
                if hasattr(parts, subattr):
                    self.add(getattr(parts, subattr))

    def __getitem__(self, index):
        """Index into the horizontal parts, else the vertical parts, else ``split()``."""
        if hasattr(self, "horizontal_parts"):
            return self.horizontal_parts[index]
        elif hasattr(self, "vertical_parts"):
            return self.vertical_parts[index]
        return self.split()[index]
class BarChart(VGroup):
    """This is a class for Bar Charts.

    Each value is drawn as a :class:`Rectangle` of height
    ``value / max_value * height``. The bars and the gaps between them all
    share the same width, ``width / (2 * len(values) + 1)``.

    Parameters
    ----------
    values
        The values for the bar chart. Any iterable is accepted (including a
        generator); it is materialized into a list once.
    height
        The height of the axes.
    width
        The width of the axes.
    n_ticks
        Number of ticks.
    tick_width
        Width of the ticks.
    label_y_axis
        Whether to label each y axis tick with its value, rounded to two
        decimals.
    y_axis_label_height
        Height of the label.
    max_value
        Value that corresponds to a full-height bar. If ``None``,
        ``max(values)`` is used.
    bar_colors
        The colors of the bars, applied as a gradient.
    bar_fill_opacity
        The opacity of the bars.
    bar_stroke_width
        The stroke width of the bars.
    bar_names
        The names of each bar.
    bar_label_scale_val
        The label size.

    Examples
    --------
    .. manim:: BarChartExample
        :save_last_frame:

        class BarChartExample(Scene):
            def construct(self):
                pull_req = [54, 23, 47, 48, 40, 64, 112, 87]
                versions = [
                    "v0.1.0",
                    "v0.1.1",
                    "v0.2.0",
                    "v0.3.0",
                    "v0.4.0",
                    "v0.5.0",
                    "v0.6.0",
                    "v0.7.0",
                ]
                colors = ["#003f5c", "#58508d", "#bc5090", "#ff6361", "#ffa600"]
                bar = BarChart(
                    pull_req,
                    max_value=max(pull_req),
                    bar_colors=colors,
                    bar_names=versions,
                    bar_label_scale_val=0.3,
                )
                self.add(bar)
    """

    def __init__(
        self,
        values: Iterable[float],
        height: float = 4,
        width: float = 6,
        n_ticks: int = 4,
        tick_width: float = 0.2,
        label_y_axis: bool = True,
        y_axis_label_height: float = 0.25,
        max_value: float = 1,
        bar_colors: Iterable = (BLUE, YELLOW),
        bar_fill_opacity: float = 0.8,
        bar_stroke_width: float = 3,
        bar_names: Iterable[str] = (),
        bar_label_scale_val: float = 0.75,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # ``values`` is read several times (max, len, enumerate). Materialize
        # it so one-shot iterables such as generators are not exhausted early.
        values = list(values)
        self.n_ticks = n_ticks
        self.tick_width = tick_width
        self.label_y_axis = label_y_axis
        self.y_axis_label_height = y_axis_label_height
        self.max_value = max_value
        # Copy into per-instance lists so no chart shares mutable state with
        # another chart or with a default argument.
        self.bar_colors = list(bar_colors)
        self.bar_fill_opacity = bar_fill_opacity
        self.bar_stroke_width = bar_stroke_width
        self.bar_names = list(bar_names)
        self.bar_label_scale_val = bar_label_scale_val
        self.total_bar_width = width
        self.total_bar_height = height

        if self.max_value is None:
            self.max_value = max(values)

        self.add_axes()
        self.add_bars(values)
        self.center()

    def add_axes(self):
        """Create the axes, the y axis ticks and (optionally) their labels.

        ``n_ticks + 1`` ticks are spaced evenly from 0 to ``total_bar_height``
        and represent values from 0 to ``max_value``. They are added to the y
        axis. Sets ``self.x_axis``, ``self.y_axis`` and, when ``label_y_axis``
        is true, ``self.y_axis_labels``.
        """
        x_axis = Line(self.tick_width * LEFT / 2, self.total_bar_width * RIGHT)
        y_axis = Line(MED_LARGE_BUFF * DOWN, self.total_bar_height * UP)
        ticks = VGroup()
        heights = np.linspace(0, self.total_bar_height, self.n_ticks + 1)
        values = np.linspace(0, self.max_value, self.n_ticks + 1)
        for y, _value in zip(heights, values):
            tick = Line(LEFT, RIGHT)
            tick.width = self.tick_width
            tick.move_to(y * UP)
            ticks.add(tick)
        y_axis.add(ticks)

        self.add(x_axis, y_axis)
        self.x_axis, self.y_axis = x_axis, y_axis

        if self.label_y_axis:
            labels = VGroup()
            for tick, value in zip(ticks, values):
                label = MathTex(str(np.round(value, 2)))
                label.height = self.y_axis_label_height
                label.next_to(tick, LEFT, SMALL_BUFF)
                labels.add(label)
            self.y_axis_labels = labels
            self.add(labels)

    def add_bars(self, values):
        """Create one bar per value, plus a name label under each named bar.

        The bar width ``buff`` is also the gap between bars. Bar ``i`` has its
        lower-left corner at ``(2 * i + 1) * buff`` along the x axis. Sets
        ``self.bars`` and ``self.bar_labels``.
        """
        buff = float(self.total_bar_width) / (2 * len(values) + 1)
        bars = VGroup()
        for i, value in enumerate(values):
            bar = Rectangle(
                height=(value / self.max_value) * self.total_bar_height,
                width=buff,
                stroke_width=self.bar_stroke_width,
                fill_opacity=self.bar_fill_opacity,
            )
            bar.move_to((2 * i + 1) * buff * RIGHT, DOWN + LEFT)
            bars.add(bar)
        bars.set_color_by_gradient(*self.bar_colors)

        bar_labels = VGroup()
        for bar, name in zip(bars, self.bar_names):
            label = MathTex(str(name))
            label.scale(self.bar_label_scale_val)
            label.next_to(bar, DOWN, SMALL_BUFF)
            bar_labels.add(label)

        self.add(bars, bar_labels)
        self.bars = bars
        self.bar_labels = bar_labels

    def change_bar_values(self, values):
        """Resize the bars in place to new values, keeping their bottoms fixed.

        Bars are paired with ``values`` via ``zip``, so bars beyond
        ``len(values)`` are left unchanged.
        """
        for bar, value in zip(self.bars, values):
            bar_bottom = bar.get_bottom()
            bar.stretch_to_fit_height((value / self.max_value) * self.total_bar_height)
            bar.move_to(bar_bottom, DOWN)