import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
import seaborn as sns
import jax.numpy as jnp
import jax
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib as mpl

# Hide the right and top axes spines on every figure drawn from here on.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})
# Shorthand for the TFP distributions namespace (JAX substrate).
dist = tfp.distributions

# Standard normal: location 0, scale 1.
normal = dist.Normal(loc=0.0, scale=1.0)
normal  # echo the distribution object as the cell output
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>

Properties of the random variable

# Mean of the distribution.
normal.mean()
DeviceArray(0., dtype=float32)
# Location parameter, as passed to the constructor (loc = 0.).
normal.loc
DeviceArray(0., dtype=float32)
# Scale parameter, as passed to the constructor (scale = 1.).
normal.scale
DeviceArray(1., dtype=float32)
# Standard deviation of the distribution.
normal.stddev()
DeviceArray(1., dtype=float32)
# Entropy of the distribution; for a standard normal this is
# 0.5 * log(2 * pi * e) ~= 1.4189 (natural log).
normal.entropy()
DeviceArray(1.4189385, dtype=float32)

Drawing samples

# Fixed PRNG key so the draw is reproducible.
key = jax.random.PRNGKey(0)

# Draw 1000 i.i.d. samples from the distribution.
sample = normal.sample(sample_shape=(1000,), seed=key)

# Histogram of the drawn samples.
sns.displot(sample)
<seaborn.axisgrid.FacetGrid at 0x1afdc6d00>
../../_images/3e61490ee3378ca6dc524308d44bcb55bad24f90a959201a9baadb16482501dc.png
# Kernel density estimate of the samples, with a widened bandwidth
# (bw_adjust=2) and a rug of the individual draws along the axis.
sns.displot(
    sample,
    kind='kde',
    bw_adjust=2,
    rug=True,
)
<seaborn.axisgrid.FacetGrid at 0x1afe60280>
../../_images/2857b929bf6dfac670aee5658ed91533ea3235732395df581813bfc40f386b18.png

Finding the log-pdf at a given point

# Log-density at x = 1. Note this is the log of the pdf, not the pdf
# itself; exponentiate the result to obtain the density value.
normal.log_prob(1.)
DeviceArray(-1.4189385, dtype=float32)

Plotting the pdf over a given range

# Evaluation grid: 100 evenly spaced points on [-5, 5].
x = jnp.linspace(-5.0, 5.0, num=100)

# The density is recovered by exponentiating the log-density.
log_pdf_x = normal.log_prob(x)
pdf_x = jnp.exp(log_pdf_x)

plt.plot(x, pdf_x)
plt.xlabel("x")
plt.ylabel("PDF")
Text(0, 0.5, 'PDF')
../../_images/ec6864dd93e036aee69ef52340a1374dd90240f61127363a84cb79f4d9f2353f.png

Plotting the cdf over a given range

# Cumulative distribution function evaluated on the same grid `x`.
cdf_x = normal.cdf(x)
plt.plot(x, cdf_x)
[<matplotlib.lines.Line2D at 0x1b04bb850>]
../../_images/af60a4eb6dbab493ab669c8b2b97a91b333ba3e00c19a79fc1c8973d7d9de2c4.png
# Perplexity is the exponential of the entropy (entropy is in nats, so the
# base is e). Use jnp.exp directly rather than jnp.power(jnp.e, ...): it is
# the idiomatic form, matches how the pdf is recovered from log_prob earlier
# in this notebook, and avoids going through a rounded float32 constant for e.
perplexity = jnp.exp(normal.entropy())
perplexity
DeviceArray(4.132731, dtype=float32)