To reproduce this interactive plot in Python with PyTorch, run the following code in a Jupyter notebook (the sliders rely on ipywidgets):
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.distributions import Beta
from torch.distributions.kl import kl_divergence
from ipywidgets import interactive, FloatSlider, Layout
from IPython.display import display
def plot_beta_interactive(a=1.0, b=1.0, a2=2.0, b2=2.0):
    """Plot two Beta densities on shared axes and annotate their KL divergence.

    P is Beta(a, b) and Q is Beta(a2, b2). Both densities are drawn as
    shaded curves, and KL(P || Q) is shown in a text box on the axes.

    Args:
        a, b: Concentration parameters (alpha, beta) of P.
        a2, b2: Concentration parameters (alpha, beta) of Q.
    """
    # Coerce to float so integer arguments do not create integer tensors,
    # which the Beta distribution's internals cannot handle.
    p = Beta(torch.tensor(float(a)), torch.tensor(float(b)))
    q = Beta(torch.tensor(float(a2)), torch.tensor(float(b2)))

    # Sample the open interval (0, 1): at x = 0 / x = 1 the density is
    # infinite whenever the matching concentration parameter is below 1,
    # and non-finite values cannot be drawn or shaded meaningfully.
    x = torch.linspace(0.0, 1.0, 1002)[1:-1]

    x_np = x.numpy()
    pdf_p = p.log_prob(x).exp().numpy()
    pdf_q = q.log_prob(x).exp().numpy()

    # KL(P || Q) between the two Beta distributions, as a Python float.
    kl = kl_divergence(p, q).item()

    plt.figure(figsize=(10, 6))
    plt.plot(x_np, pdf_p, label=f'Beta(α={a:.2f}, β={b:.2f})')
    plt.fill_between(x_np, pdf_p, alpha=0.3)
    plt.plot(x_np, pdf_q, label=f'Beta(α={a2:.2f}, β={b2:.2f})')
    plt.fill_between(x_np, pdf_q, alpha=0.3)

    # Text position is given in axes coordinates (transAxes), so the box
    # stays in the same place regardless of the data range.
    plt.text(0.4, 0.85, f'KL(P||Q) = {kl:.4f}', transform=plt.gca().transAxes,
             verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    plt.xlabel('x')
    plt.ylabel('Probability Density')
    plt.title(f'Beta Distributions: (α={a:.2f}, β={b:.2f}) and (α={a2:.2f}, β={b2:.2f})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
# Create interactive plot
def create_interactive_beta_plot():
    """Render slider controls that drive ``plot_beta_interactive``.

    Four ``FloatSlider`` widgets are built: ``a``/``b`` for the first Beta
    distribution and ``a2``/``b2`` for the second. Each covers
    [0.001, 100] in steps of 0.001 and is 500px wide. The resulting
    ``interactive`` widget is passed to ``display``.
    """
    width_layout = Layout(width='500px')

    # (keyword argument, slider label, starting value), in the same order
    # as plot_beta_interactive's parameters.
    slider_specs = (
        ('a', 'a:', 1),
        ('b', 'b:', 1),
        ('a2', 'a2:', 2),
        ('b2', 'b2:', 2),
    )

    controls = {}
    for param_name, label, start in slider_specs:
        controls[param_name] = FloatSlider(
            min=0.001,
            max=100,
            step=0.001,
            value=start,
            description=label,
            layout=width_layout,
        )

    display(interactive(plot_beta_interactive, **controls))
# Build and show the widget when this code runs as a script or notebook cell.
# The guard keeps a plain `import` of this module free of display side effects;
# Jupyter/IPython run cells with __name__ == "__main__", so notebooks still show it.
if __name__ == "__main__":
    create_interactive_beta_plot()