import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
import seaborn as sns
import jax.numpy as jnp
import jax
%matplotlib inline
%config InlineBackend.figure_format='retina'
import matplotlib as mpl
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False
dist = tfp.distributions
normal = dist.Normal(loc = 0., scale = 1.)
normal
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>
Properties of the random variable
normal.mean()  # mean of the distribution; for a Normal this equals `loc`
DeviceArray(0., dtype=float32)
normal.loc  # location parameter passed to the constructor
DeviceArray(0., dtype=float32)
normal.scale  # scale parameter passed to the constructor (the std. dev. for a Normal)
DeviceArray(1., dtype=float32)
normal.stddev()  # standard deviation; for a Normal this equals `scale`
DeviceArray(1., dtype=float32)
normal.entropy()  # differential entropy in nats: 0.5 * log(2 * pi * e * scale**2)
DeviceArray(1.4189385, dtype=float32)
Drawing samples
# A fixed PRNG key makes the draw reproducible across runs.
key = jax.random.PRNGKey(0)
# Draw 1000 independent samples from `normal` and show their histogram.
sample = normal.sample(sample_shape=(1000,), seed=key)
sns.displot(sample)
<seaborn.axisgrid.FacetGrid at 0x1afdc6d00>
# Kernel density estimate of the same samples. bw_adjust=2 doubles the
# default bandwidth (a smoother curve); rug=True adds a tick per observation.
sns.displot(sample, kind='kde',bw_adjust=2, rug=True)
<seaborn.axisgrid.FacetGrid at 0x1afe60280>
Finding the log-PDF at a given point
# log_prob returns the *logarithm* of the density at x = 1.0; exponentiate it
# (or call normal.prob) to get the PDF value itself.
normal.log_prob(1.)
DeviceArray(-1.4189385, dtype=float32)
Plotting the PDF over a given range
# Evaluation grid: 100 evenly spaced points covering [-5, 5].
x = jnp.linspace(-5., 5., num=100)
# Density values, recovered from the log-density returned by log_prob.
pdf_x = jnp.exp(normal.log_prob(x))

plt.plot(x, pdf_x)
plt.xlabel("x")
plt.ylabel("PDF")
Text(0, 0.5, 'PDF')
Plotting the CDF over a given range
# Cumulative distribution function evaluated on the grid `x`; axes are
# labelled to match the PDF plot.
plt.plot(x, normal.cdf(x))
plt.xlabel("x")
plt.ylabel("CDF")
[<matplotlib.lines.Line2D at 0x1b04bb850>]
# Perplexity = exp(H), where H is the entropy (in nats) from normal.entropy().
# jnp.exp is the direct form of jnp.power(jnp.e, H).
perplexity = jnp.exp(normal.entropy())
perplexity
DeviceArray(4.132731, dtype=float32)