Metropolis-Hastings#
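This notebook implements the Metropolis-Hastings (MH) sampler from scratch and checks it against a known target. As a quick recap (standard MH, stated here only for reference): given a target density $p(x)$ known up to a normalizing constant and a proposal $q(x' \mid x)$, a proposed move from $x$ to $x'$ is accepted with probability

$$\alpha(x \to x') = \min\left(1, \frac{p(x')\, q(x \mid x')}{p(x)\, q(x' \mid x)}\right)$$

All proposals used below are symmetric Gaussians, so the $q$ terms cancel and $\alpha = \min(1, p(x')/p(x))$.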
import torch
dist = torch.distributions
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
mix = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.2, 0.8])),
    component_distribution=dist.Normal(torch.tensor([0.1, 3.]), torch.tensor([0.6, 3.])),
)
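The target to sample from is a two-component Gaussian mixture. Its moments are available in closed form (next cell), which gives a ground truth to check the sampler against.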
x = torch.linspace(-4., 4, 5)  # 5 input locations, reused later in the regression examples
theta_range = torch.linspace(-8, 12, 100)
plt.plot(theta_range, mix.log_prob(theta_range).exp())

mix.mean, mix.stddev
(tensor(2.4200), tensor(2.9356))
next_sample = lambda cur_sample: dist.Normal(loc = cur_sample, scale=1).sample().item()
next_sample(1)
2.0207135677337646
p = lambda x: mix.log_prob(torch.tensor(x)).exp().item()
lp = lambda x: mix.log_prob(torch.tensor(x)).item()
p(2)
0.10151920467615128
num_iter = 10
xs = [None]*num_iter
xs[0] = 0.
plt.plot(theta_range, mix.log_prob(theta_range).exp())
plt.scatter(xs[0], p(xs[0]))
xs[1] = next_sample(xs[0])
plt.scatter(xs[1], p(xs[1]))

num_iter = 10
xs = [None]*num_iter
xs[0] = 0.2
plt.plot(theta_range, mix.log_prob(theta_range).exp())
plt.scatter(xs[0], p(xs[0]))
xs[1] = next_sample(xs[0])
plt.scatter(xs[1], p(xs[1]))
a = p(xs[1])/p(xs[0])
a
0.27602488498335154

u = dist.Uniform(0, 1).sample().item()
print(u)
0.049220144748687744
if u > a:
    # reject the proposal: revert to the previous sample
    xs[1] = xs[0]
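The same test is usually done in log space to avoid underflow: $u < a$ is equivalent to $\log u < \log p(x') - \log p(x)$, so in the loop below a proposal is rejected whenever lu[i] > la.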
x_start = 0.
num_iter = 20000
xs = torch.empty(num_iter)
xs[0] = x_start
lu = torch.log(dist.Uniform(0, 1).sample([num_iter]))
for i in range(1, num_iter):
    xs[i] = next_sample(xs[i-1])
    # log acceptance ratio (the symmetric proposal terms cancel)
    la = lp(xs[i]) - lp(xs[i-1])
    if lu[i] > la:
        # reject: keep the previous sample
        xs[i] = xs[i-1]
plt.plot(xs)

import seaborn as sns
plt.plot(theta_range, mix.log_prob(theta_range).exp(), label='True')
sns.kdeplot(xs[:10000].numpy(), label='Samples obtained from MH')
plt.legend()

xs[:1000].mean(), mix.mean, xs[:1000].std(), mix.stddev
(tensor(3.5472), tensor(2.4200), tensor(3.0889), tensor(2.9356))
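Using only the first 1000 samples, the mean estimate is noticeably off. This is expected: the chain starts at 0, successive samples are highly correlated under a random-walk proposal, and no burn-in has been discarded, so short prefixes of the chain are biased.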
import emcee
import numpy as np
g = emcee.moves.GaussianMove(cov=1.)
nwalkers = 2
ndim = 1
log_prob = lp
p0 = np.random.rand(nwalkers, ndim)
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, moves=g)
state = sampler.run_mcmc(p0, 100)  # short burn-in run
sampler.reset()  # discard the burn-in samples
sampler.run_mcmc(state, 10000, progress=True);
100%|██████████| 10000/10000 [00:03<00:00, 3105.40it/s]
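emcee's GaussianMove is a Metropolis-Hastings move with a Gaussian proposal, so this run mirrors the hand-written sampler above; the ensemble machinery is not doing anything extra here.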
samples = sampler.get_chain(flat=True)
plt.plot(samples)

sns.kdeplot(samples.flatten(), label='emcee')
plt.plot(theta_range, mix.log_prob(theta_range).exp(), label='True')
sns.kdeplot(xs[:10000].numpy(), label='Samples obtained from MH')
plt.legend()

Creating a function#
def mh(log_p, next_sample, num_iter, x_start):
    # pre-sample the log-uniforms used in the accept/reject test
    lu = torch.log(dist.Uniform(0, 1).sample([num_iter]))
    try:
        # vector-valued state
        xs = torch.empty((num_iter, len(x_start)))
    except TypeError:
        # scalar state (x_start has no len())
        xs = torch.empty(num_iter)
    xs[0] = x_start
    for i in range(1, num_iter):
        xs[i] = next_sample(xs[i-1])
        # log acceptance ratio (symmetric proposal assumed)
        la = log_p(xs[i]) - log_p(xs[i-1])
        if lu[i] > la:
            # reject: keep the previous sample
            xs[i] = xs[i-1]
    return xs
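The try/except dispatch lets the same routine allocate storage for scalar or vector states, so the 2D example below can reuse it unchanged.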
xs = mh(lp, next_sample, 10000, 0.)
plt.plot(xs)

2D#
dist_2d = dist.MultivariateNormal(loc = torch.zeros(2), covariance_matrix=torch.tensor([[1., 0.5], [0.5, 2.]]))
dist_2d
MultivariateNormal(loc: torch.Size([2]), covariance_matrix: torch.Size([2, 2]))
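The off-diagonal covariance of 0.5 makes the two coordinates positively correlated, which the density plot and corner plot below should recover.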
log_p_2d = lambda x: dist_2d.log_prob(torch.as_tensor(x)).item()
next_sample_2d = lambda cur_sample: dist.MultivariateNormal(
    loc=cur_sample, covariance_matrix=torch.eye(len(cur_sample))
).sample()
x_2d = mh(log_p_2d, next_sample_2d, 5000, torch.tensor([0., 0.]))
sns.kdeplot(x=x_2d[:, 0], y=x_2d[:, 1], bw_adjust=4)

import corner
corner.corner(x_2d.numpy(), smooth=2, show_titles=True, labels=[r"$x_1$", r"$x_2$"], smooth1d=1);

Linear Regression with 1 parameter#
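Here the target is the unnormalized posterior over the slope $\theta$ of a no-intercept linear model $y = \theta x + \epsilon$. By Bayes' rule,

$$\log p(\theta \mid \mathcal{D}) = \log p(\mathcal{D} \mid \theta) + \log p(\theta) + \text{const},$$

so the MH target is the sum of the log-likelihood and the log-prior.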
y = 4*x + 0.5*torch.randn(5)  # synthetic targets: true slope 4, Gaussian noise with std 0.5
prior = dist.Normal(loc = 0., scale = 1.)
log_likelihood = lambda t: dist.Normal(loc = x*t, scale=1.).log_prob(y).sum(axis=0)
unnorm_post = lambda t: log_likelihood(torch.tensor(t)).item() + prior.log_prob(torch.tensor(t)).item()
xs = mh(unnorm_post, next_sample, 10000, 0.)
sns.kdeplot(xs.numpy())

sns.kdeplot(xs.numpy(), label='Samples obtained from MH')
plt.axvline(0., label='Prior mean', color='k', linestyle='--')
plt.axvline(4, label='True value', color='g', linestyle='-.')
plt.legend()

plt.plot(xs)

for i in range(100):
    plt.plot(x, xs[i]*x, alpha=0.1, color='k')
plt.scatter(x, y, zorder=10)

xs_mean = xs.mean()
xs_std = xs.std()
plt.plot(x, xs_mean*x, color='k')
plt.scatter(x, y, zorder=10)
plt.fill_between(x, (xs_mean-2*xs_std)*x, (xs_mean+2*xs_std)*x, color='k', alpha=0.1)
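Because the model has no intercept, the two-standard-deviation band $(\bar\theta \pm 2\sigma_\theta)\,x$ pinches to zero width at $x = 0$; the next section adds an intercept parameter.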

Linear Regression with 2 parameters#
y = 4*x + 0.5*torch.randn(5)
prior = dist.MultivariateNormal(loc = torch.zeros([2]), covariance_matrix=torch.eye(2))
log_likelihood = lambda t: dist.Normal(loc = x*t[1] + t[0], scale=1.).log_prob(y).sum(axis=0)
unnorm_post = lambda t: log_likelihood(torch.as_tensor(t)).item() + prior.log_prob(torch.as_tensor(t)).item()
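The cell above defines the two-parameter target but stops before sampling it. A minimal sketch of the natural next step, reusing mh, next_sample_2d, and corner from earlier (the variable name xs_2p is illustrative):

# sample the 2-parameter posterior with the same MH routine and
# symmetric 2D Gaussian proposal used in the 2D example above
xs_2p = mh(unnorm_post, next_sample_2d, 10000, torch.tensor([0., 0.]))
# marginal and pairwise posteriors over (intercept, slope)
corner.corner(xs_2p.numpy(), show_titles=True, labels=[r"$\theta_0$ (intercept)", r"$\theta_1$ (slope)"]);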