import torch
dist = torch.distributions
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns  # used for the KDE plots later on
%matplotlib inline
Simple rejection sampling#
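In outline: given an (unnormalized) target p_tilde and a proposal q scaled so that c*q(x) >= p_tilde(x) for all x, draw s ~ q and u ~ Uniform(0, c*q(s)), and accept s whenever u <= p_tilde(s); the accepted s are exact draws from the normalized target. The cells below walk through this for a two-component Gaussian mixture.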
mix = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.2, 0.8])),
    component_distribution=dist.Normal(torch.tensor([0.1, 3]), torch.tensor([0.6, 3])),
)
x = torch.linspace(-4., 4, 5)  # 5 evenly spaced inputs; used later in the Linear Regression section
theta_range = torch.linspace(-8, 12, 100)
plt.plot(theta_range, mix.log_prob(theta_range).exp())
[<matplotlib.lines.Line2D at 0x14871f760>]

mix.mean, mix.stddev
(tensor(2.4200), tensor(2.9356))
q = dist.Normal(loc = 0., scale=1.)
plt.plot(theta_range, q.log_prob(theta_range).exp())
[<matplotlib.lines.Line2D at 0x148788e20>]

plt.plot(theta_range, mix.log_prob(theta_range).exp(), label='p_tilde')
plt.plot(theta_range, q.log_prob(theta_range).exp(), label='q')
plt.legend()
<matplotlib.legend.Legend at 0x1487f25e0>

Choosing a proposal q with higher variance, since the standard normal above is too narrow to cover the target
q = dist.Normal(loc = 0., scale=5.)
plt.plot(theta_range, mix.log_prob(theta_range).exp(), label='p_tilde')
plt.plot(theta_range, q.log_prob(theta_range).exp(), label='q')
plt.legend()
<matplotlib.legend.Legend at 0x147ce8a00>

Scaling q so that c*q(x) >= p_tilde(x) for all x
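The tightest valid constant is c = max_x p_tilde(x)/q(x); the next cell approximates this maximum over the plotting grid, so the bound is only as tight as the grid allows.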
# restrict to the region where the target exceeds the proposal and take the largest ratio p_tilde/q on the grid
theta_more = theta_range[mix.log_prob(theta_range).exp() > q.log_prob(theta_range).exp()]
c = (mix.log_prob(theta_more).exp()/q.log_prob(theta_more).exp()).max().item()
q = dist.Normal(loc = 0., scale=5.)
plt.plot(theta_range, mix.log_prob(theta_range).exp(), label='p_tilde')
plt.plot(theta_range, c*q.log_prob(theta_range).exp(), label='c*q')
plt.legend()
<matplotlib.legend.Legend at 0x148b6d610>

plt.gca().get_xlim()
(0.0, 1.0)

def plot_rejection(ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    s = q.sample()  # propose s ~ q
    ax.plot(theta_range, mix.log_prob(theta_range).exp(), label='p_tilde')
    ax.plot(theta_range, c*q.log_prob(theta_range).exp(), label='c*q')
    ax.scatter([s], [0], label='s (sample from q)', s=100, color='k')
    cqs = c*q.log_prob(s).exp().item()
    ax.axhline(cqs, label='c*q(s)', color='k')
    ax.plot([s, s], [0, cqs], color='k', linestyle='-.')
    u = dist.Uniform(0, cqs).sample()  # u ~ Uniform(0, c*q(s))
    if u > mix.log_prob(s).exp():
        ax.scatter([s], [u], label='u', s=100, color='r')  # rejected
    else:
        ax.scatter([s], [u], label='u', s=100, color='g')  # accepted
    ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
plot_rejection()

fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(20, 20))
for i in range(16):
    plot_rejection(ax[i//4, i%4])
fig.tight_layout()

N = 5000
s = q.sample([N])
fig, ax = plt.subplots()
ax.plot(theta_range, mix.log_prob(theta_range).exp(), label='p_tilde')
ax.plot(theta_range, c*q.log_prob(theta_range).exp(), label='c*q')
ax.scatter(s, torch.zeros_like(s), label='s (sample from q)', s=5,color='k', marker='|')
cqs = c*q.log_prob(s).exp()
u = dist.Uniform(0, cqs).sample()
gt_index = u > mix.log_prob(s).exp()  # True where u exceeds p_tilde(s), i.e. rejected
gt = u[gt_index]        # rejected u values
s_gt = s[gt_index]      # rejected proposals
s_lt = s[~gt_index]     # accepted proposals (draws from the target)
lt = u[~gt_index]       # accepted u values
ax.scatter(s_gt, gt, s=1,color='r', label='Rejected')
ax.scatter(s_lt, lt, label='Accepted', s=1,color='g')
ax.legend()
<matplotlib.legend.Legend at 0x14c6efb20>

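Since the mixture density here is in fact normalized, the expected acceptance rate is 1/c. A quick check against the run above (an added sketch reusing gt_index and c):
(~gt_index).float().mean().item(), 1/c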
# compare the KDE of accepted samples with direct samples from the mixture
sns.kdeplot(s_lt)
sns.kdeplot(mix.sample([1000]))
<AxesSubplot:ylabel='Density'>

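A numeric spot-check of the accepted samples against the mixture's analytic moments computed earlier (a small added sketch reusing s_lt and mix):
s_lt.mean(), s_lt.std(), mix.mean, mix.stddev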
Linear Regression#
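The model below is y_i = theta*x_i + eps_i with eps_i ~ N(0, 1), a prior theta ~ N(0, 1), and data generated with a true slope of 4; the goal is to rejection-sample theta from the unnormalized posterior p_tilde(theta) = p(y | theta) p(theta).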
y = 4*x + torch.randn(5)  # true slope 4, unit-variance Gaussian noise
plt.scatter(x, y)
<matplotlib.collections.PathCollection at 0x146cd6f40>

prior = dist.Normal(loc = 0., scale = 1.)
likelihood = lambda t: dist.Normal(loc = x*t, scale=1.).log_prob(y).sum(axis=0).exp().item()  # p(y | theta)
log_likelihood = lambda t: dist.Normal(loc = x*t, scale=1.).log_prob(y).sum(axis=0)  # log p(y | theta)
lls = {}
ls = {}
theta_range = torch.linspace(-1, 7, 100)
for t in theta_range:
    lls[t.item()] = log_likelihood(t).item()
    ls[t.item()] = log_likelihood(t).exp().item()
pd.Series(ls).plot()
<AxesSubplot:>

pd.Series(lls).plot()
<AxesSubplot:>

plt.plot(theta_range, prior.log_prob(theta_range).exp())
[<matplotlib.lines.Line2D at 0x146e26700>]

unnorm_post = lambda t: log_likelihood(t).item() + prior.log_prob(t).item()  # log of the unnormalized posterior
up = {}
for t in theta_range:
    up[t.item()] = unnorm_post(t)
fig, ax = plt.subplots(nrows=3, sharex=True, figsize= (8, 6))
pd.Series(up).plot(ax=ax[0], title="Unnorm Posterior")
pd.Series(lls).plot(ax=ax[1], title='Loglik')
ax[2].plot(theta_range, prior.log_prob(theta_range))
[<matplotlib.lines.Line2D at 0x146e891c0>]

torch.exp(torch.tensor(pd.Series(up).values))
tensor([7.6395e-249, 1.9184e-241, 3.6856e-234, 5.4180e-227, 6.0933e-220,
5.2437e-213, 3.4526e-206, 1.7393e-199, 6.7039e-193, 1.9769e-186,
4.4607e-180, 7.7007e-174, 1.0172e-167, 1.0280e-161, 7.9493e-156,
4.7028e-150, 2.1288e-144, 7.3728e-139, 1.9538e-133, 3.9613e-128,
6.1447e-123, 7.2932e-118, 6.6227e-113, 4.6016e-108, 2.4463e-103,
9.9501e-99, 3.0965e-94, 7.3730e-90, 1.3432e-85, 1.8722e-81,
1.9967e-77, 1.6293e-73, 1.0172e-69, 4.8588e-66, 1.7758e-62,
4.9658e-59, 1.0625e-55, 1.7392e-52, 2.1783e-49, 2.0875e-46,
1.5306e-43, 8.5864e-41, 3.6854e-38, 1.2103e-35, 3.0411e-33,
5.8465e-31, 8.5998e-29, 9.6786e-27, 8.3341e-25, 5.4907e-23,
2.7677e-21, 1.0675e-19, 3.1500e-18, 7.1120e-17, 1.2286e-15,
1.6238e-14, 1.6421e-13, 1.2706e-12, 7.5216e-12, 3.4069e-11,
1.1807e-10, 3.1306e-10, 6.3512e-10, 9.8585e-10, 1.1708e-09,
1.0639e-09, 7.3967e-10, 3.9346e-10, 1.6014e-10, 4.9867e-11,
1.1881e-11, 2.1659e-12, 3.0209e-13, 3.2237e-14, 2.6322e-15,
1.6444e-16, 7.8598e-18, 2.8744e-19, 8.0428e-21, 1.7219e-22,
2.8205e-24, 3.5348e-26, 3.3896e-28, 2.4869e-30, 1.3960e-32,
5.9956e-35, 1.9702e-37, 4.9536e-40, 9.5292e-43, 1.4026e-45,
1.5795e-48, 1.3609e-51, 8.9718e-55, 4.5254e-58, 1.7464e-61,
5.1568e-65, 1.1650e-68, 2.0139e-72, 2.6634e-76, 2.6951e-80],
dtype=torch.float64)
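The unnormalized posterior is tiny in absolute terms (its peak is around 1e-9), which is why the envelope comparison below is done on the log scale. A quick look at the peak, reusing up from above:
peak_log = max(up.values())
peak_log, torch.tensor(peak_log).exp()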
q = dist.Normal(loc = 3., scale=.4)  # proposal over theta
c = 1e-1  # envelope constant, compared against the unnormalized posterior on the log scale below
q.log_prob(theta_range) + torch.log(torch.tensor(c))
tensor([-52.3052, -50.3054, -48.3465, -46.4283, -44.5509, -42.7144, -40.9186,
-39.1637, -37.4496, -35.7763, -34.1438, -32.5521, -31.0013, -29.4912,
-28.0220, -26.5936, -25.2060, -23.8592, -22.5532, -21.2880, -20.0636,
-18.8801, -17.7373, -16.6354, -15.5743, -14.5540, -13.5745, -12.6358,
-11.7379, -10.8809, -10.0646, -9.2892, -8.5546, -7.8608, -7.2078,
-6.5956, -6.0242, -5.4937, -5.0039, -4.5550, -4.1469, -3.7796,
-3.4531, -3.1674, -2.9225, -2.7185, -2.5552, -2.4328, -2.3511,
-2.3103, -2.3103, -2.3511, -2.4328, -2.5552, -2.7185, -2.9225,
-3.1674, -3.4531, -3.7796, -4.1469, -4.5550, -5.0039, -5.4937,
-6.0242, -6.5956, -7.2078, -7.8608, -8.5546, -9.2892, -10.0646,
-10.8809, -11.7379, -12.6358, -13.5745, -14.5540, -15.5743, -16.6354,
-17.7373, -18.8801, -20.0636, -21.2880, -22.5532, -23.8592, -25.2059,
-26.5936, -28.0220, -29.4912, -31.0013, -32.5521, -34.1438, -35.7763,
-37.4496, -39.1637, -40.9186, -42.7144, -44.5509, -46.4283, -48.3464,
-50.3054, -52.3052])
plt.plot(theta_range, q.log_prob(theta_range)+torch.log(torch.tensor(c)), label='log c*q(theta)')
pd.Series(up).plot(ax=plt.gca(), label='log unnorm posterior')
plt.legend()
<matplotlib.legend.Legend at 0x1471fb7c0>

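A grid check (an added sketch reusing theta_range, up, q, and c) that the scaled proposal stays above the unnormalized posterior on the log scale:
log_cq = q.log_prob(theta_range) + torch.log(torch.tensor(c))
log_up = torch.tensor(list(up.values()))
(log_up <= log_cq).all()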
plt.plot(theta_range, q.log_prob(theta_range)+torch.log(torch.tensor(c)), label='log c*q(theta)')
[<matplotlib.lines.Line2D at 0x14727a940>]

s = q.sample()
s
tensor(3.0192)
q_s = q.log_prob(s).exp()
q_s
tensor(0.9962)
u = dist.Uniform(low=0., high=q_s).sample()
u
tensor(0.8261)
log_q_s = q.log_prob(s).item()
log_p_bar_s = unnorm_post(s)
log_u = torch.log(u).item()
log_q_s, log_p_bar_s, log_u
(-0.0037949681282043457, -48.340909004211426, -0.1910155713558197)
if log_u > log_p_bar_s:
    pass
# Note: with c = 1e-1 the envelope c*q sits far above the unnormalized posterior (whose peak is ~1e-9),
# so acceptances are extremely rare over 10000 proposals.
considered = []
for i in range(10000):
    s = q.sample()                                  # propose theta ~ q
    q_s = q.log_prob(s).exp()
    u = c*dist.Uniform(low=0., high=q_s).sample()   # u ~ Uniform(0, c*q(s))
    log_q_s = q.log_prob(s).item()
    log_p_bar_s = unnorm_post(s)                    # log unnormalized posterior at s
    log_u = torch.log(u).item()
    if log_u > log_p_bar_s:
        pass                                        # reject
    else:
        considered.append(s.item())                 # accept the proposed theta
import seaborn as sns
sns.kdeplot(considered)
<AxesSubplot:>
