import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
import seaborn as sns
import jax.numpy as jnp
import jax
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib as mpl

# Hide the top and right axes spines on the figures drawn below.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Normal distribution with loc 0 and scale 1; the bare name displays its repr.
dist = tfp.distributions
normal = dist.Normal(loc=0.0, scale=1.0)
normal
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>

Properties of the random variable¶

# Mean of the distribution.
normal.mean()
DeviceArray(0., dtype=float32)
# The `loc` parameter the distribution was constructed with.
normal.loc
DeviceArray(0., dtype=float32)
# The `scale` parameter the distribution was constructed with.
normal.scale
DeviceArray(1., dtype=float32)
# Standard deviation of the distribution.
normal.stddev()
DeviceArray(1., dtype=float32)
# Entropy of the distribution, in nats (0.5 * log(2 * pi * e) for unit scale).
normal.entropy()
DeviceArray(1.4189385, dtype=float32)

Drawing samples¶

# Draw 1000 samples using an explicit JAX PRNG key, then plot them.
key = jax.random.PRNGKey(0)
sample = normal.sample(sample_shape=(1000,), seed=key)
sns.displot(sample)
<seaborn.axisgrid.FacetGrid at 0x1afdc6d00>
../../_images/univariate-normal_11_1.png
# Same samples as a kernel density estimate; bw_adjust=2 scales the bandwidth
# up for a smoother curve, and rug=True adds a tick per sample along the axis.
sns.displot(sample, kind='kde',bw_adjust=2, rug=True)
<seaborn.axisgrid.FacetGrid at 0x1afe60280>
../../_images/univariate-normal_12_1.png

Finding log pdf at a given point¶

# log_prob returns the log of the density at 1.0; apply jnp.exp to get the pdf value.
normal.log_prob(1.)
DeviceArray(-1.4189385, dtype=float32)

Plotting pdf over given range¶

# Evaluate the density on 100 evenly spaced points in [-5, 5].
x = jnp.linspace(-5.0, 5.0, num=100)
pdf_x = jnp.exp(normal.log_prob(x))  # log_prob is in log space, so exponentiate

# Plot the density against x with labelled axes.
plt.plot(x, pdf_x)
plt.xlabel("x")
plt.ylabel("PDF")
Text(0, 0.5, 'PDF')
../../_images/univariate-normal_17_1.png

Plotting cdf over given range¶

# Plot the cumulative distribution function evaluated at every point of x.
plt.plot(x, normal.cdf(x))
[<matplotlib.lines.Line2D at 0x1b04bb850>]
../../_images/univariate-normal_19_1.png
# Perplexity is the exponential of the entropy; entropy() is in nats, hence base e.
# jnp.exp is the direct form of e**x and matches how log_prob is exponentiated
# for the pdf plot, instead of raising a float-cast jnp.e to a power.
perplexity = jnp.exp(normal.entropy())
perplexity
DeviceArray(4.132731, dtype=float32)