Rejection sampling
Author: Nipun Batra
https://www.youtube.com/watch?v=kYWHfgkRc9s
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from scipy.stats import expon
import seaborn as sns
%matplotlib inline
# Global plot styling: font size 16 and text rendered through LaTeX.
_RC_OVERRIDES = {'font': {'size': 16}, 'text': {'usetex': True}}
for _group, _options in _RC_OVERRIDES.items():
    rc(_group, **_options)
Exponential distribution
Plotting the pdf of the exponential distribution
# Frozen exponential distribution with scipy's default parameters.
rv = expon()

# Evaluate the density on 1000 evenly spaced points covering [0, 10].
x = np.linspace(start=0, stop=10, num=1000)
density = rv.pdf(x)

plt.plot(x, density, label='pdf')
plt.xlabel('x')
plt.legend();
data:image/s3,"s3://crabby-images/dd080/dd0803012400df1328e34bdf6ed012c7290deb16" alt="../../_images/98e541ffe682fd9b8dd35015c0230f7cdf59e6eb0904507df4f9f5e562683441.png"
Generating samples from a uniform distribution
# Draw 100 points from Uniform(0, 10) and plot their kernel density estimate.
n_draws = 100
uni_samples = np.random.uniform(0, 10, n_draws)
sns.kdeplot(uni_samples, label='pdf')
plt.xlabel('x')
plt.legend();
data:image/s3,"s3://crabby-images/f2335/f233526188adbee761871b1929bf2de218db5a17" alt="../../_images/7ce657c84922bbeafe74da700bd8980d24a4ef6030eeeaf1fffaacf8780b2074.png"
# Same experiment with 100000 points from Uniform(0, 10).
n_draws = 100000
uni_samples = np.random.uniform(0, 10, n_draws)
sns.kdeplot(uni_samples, label='pdf')
plt.xlabel('x')
plt.legend();
data:image/s3,"s3://crabby-images/aa161/aa161b6f0bb364ff7fa1603bc1b45a407e665df7" alt="../../_images/df98c7a5442ab52eb8accda82cf2a3e1e4626779bf551c77fc32e1421a3fa37a.png"
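We draw points uniformly over the rectangle $[0, 10] \times [0, 1]$ that encloses the pdf and accept every point that falls in the area underneath the pdf. The x-coordinates of the accepted points are then samples from the target distribution (restricted to $[0, 10]$).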
# Reference curve: the target density on [0, 10].
x = np.linspace(0, 10, 1000)
plt.plot(x, rv.pdf(x), 'k', linewidth=2)

# Candidate points drawn uniformly from the box [0, 10] x [0, 1].
n_candidates = 100000
samples_uniform_x = np.random.uniform(0, 10, n_candidates)
samples_uniform_y = np.random.uniform(0, 1, n_candidates)

# A candidate is accepted when it lies strictly below the density at its x.
pdfs = rv.pdf(samples_uniform_x)
idx = samples_uniform_y < pdfs

# Scatter accepted and rejected candidates in different colours.
for mask, colour, tag in ((idx, 'green', "Accepted"), (~idx, 'red', "Rejected")):
    plt.scatter(samples_uniform_x[mask], samples_uniform_y[mask],
                alpha=0.3, color=colour, s=0.1, label=tag)
plt.legend(bbox_to_anchor=(1, 1));
data:image/s3,"s3://crabby-images/38d66/38d66539989cdd10b4ec5f5bc346dec8f5c33206" alt="../../_images/2b6166f5932da54f4eb0bcd67acf418c27766c4fc5faf2f8091a7de67976bedc.png"
# Histogram (100 bins) of the x-coordinates of the accepted candidates.
accepted_x = samples_uniform_x[idx]
plt.hist(accepted_x, bins=100);
data:image/s3,"s3://crabby-images/edcb8/edcb89c80f5d3952a599db14bcc37844edd8b4bb" alt="../../_images/6fae03875ea965788b3317ae89e5a0c8c3599be521e92abb7bf6957d95609211.png"
We can define a general function to do the rejection sampling.
def rejection_sampling(pdf, lower_support, upper_support, samples=1000, y_max=1):
    """Run rejection sampling against ``pdf`` and scatter-plot the outcome.

    Candidate points are drawn uniformly from the rectangle
    ``[lower_support, upper_support) x [0, y_max)``. A candidate is accepted
    when its y-coordinate lies strictly below ``pdf`` evaluated at its
    x-coordinate. Accepted points are plotted in green, rejected ones in red,
    and the plot title shows the mean of the accepted x-coordinates.

    Parameters
    ----------
    pdf : callable
        Density to sample from; called once with the array of candidate
        x-coordinates and expected to return one value per candidate.
    lower_support, upper_support : float
        Horizontal extent of the sampling rectangle.
    samples : int, optional
        Number of candidate points to draw (not the number accepted).
    y_max : float, optional
        Height of the sampling rectangle. It must be at least the maximum
        of ``pdf`` on the support, otherwise the density is clipped at
        ``y_max`` and the accepted points do not follow ``pdf``.

    Returns
    -------
    numpy.ndarray
        The x-coordinates of the accepted candidates.
    """
    import warnings

    samples_uniform_x = np.random.uniform(lower_support, upper_support, samples)
    samples_uniform_y = np.random.uniform(0, y_max, samples)
    pdfs = np.asarray(pdf(samples_uniform_x))

    # A box lower than the density silently flattens its peak; tell the
    # caller instead of returning samples from the wrong distribution.
    if np.any(pdfs > y_max):
        warnings.warn(
            f"pdf exceeds y_max={y_max} at some candidate points "
            f"(largest sampled value: {np.max(pdfs)}); the accepted samples "
            "will not follow pdf. Increase y_max.",
            stacklevel=2,
        )

    idx = samples_uniform_y < pdfs
    accepted = samples_uniform_x[idx]

    plt.scatter(accepted, samples_uniform_y[idx],
                alpha=0.6, color='green', s=0.1, label="Accepted")
    plt.scatter(samples_uniform_x[~idx], samples_uniform_y[~idx],
                alpha=0.6, color='red', s=0.1, label="Rejected")

    # Avoid numpy's "mean of empty slice" warning when nothing is accepted.
    accepted_mean = accepted.mean() if accepted.size else np.nan
    plt.title(f'mean = {accepted_mean}')
    plt.legend()
    return accepted
Normal distribution
from scipy.stats import norm

# Target: normal distribution with mean 0 and standard deviation `scale`.
scale = 1
rv = norm(loc=0, scale=scale)
pdf = rv.pdf

# 10000 candidates on [-5, 5], using the default box height (y_max = 1).
rejection_sampling(pdf, lower_support=-5, upper_support=5, samples=10000)

# Overlay the exact density on the scatter plot.
x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
data:image/s3,"s3://crabby-images/fe1e2/fe1e2cb11cfd7657c86df6af6ada73194fbf6def" alt="../../_images/522b2a45f3251ab359af4b7593a471926c3295b9c11f40dbf954f9c11e59a575.png"
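Let us try a lower value of the standard deviation (scale = 0.1) while keeping the default y_max = 1. The peak of this pdf is $1/(0.1\sqrt{2\pi}) \approx 3.99$, so the sampling box no longer covers the whole curve.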
from scipy.stats import norm

# Narrower normal: the density now peaks at 1 / (scale * sqrt(2 * pi)),
# about 3.99, which is above the default box height y_max = 1 used below.
scale = 0.1
rv = norm(loc=0, scale=scale)
pdf = rv.pdf

# Same call as for the unit-scale case; candidates only reach y = 1.
rejection_sampling(pdf, lower_support=-5, upper_support=5, samples=10000)

# Overlay the exact density on the scatter plot.
x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
data:image/s3,"s3://crabby-images/65e5e/65e5eb91eee90776f6e050b1ad6c45c67acd1cb9" alt="../../_images/01309329398e6cfe89478f3feb6dd54d918f8e652add8cd44c859af7ea53b268.png"
We need to increase the height of the sampling region in this case: y_max must be at least the peak of the pdf, $1/(\sigma\sqrt{2\pi})$.
# Narrow normal centred at 1.
scale = 0.1
rv = norm(loc=1, scale=scale)
pdf = rv.pdf

# Height of the normal density at its mean; used as the box height so the
# whole curve fits inside the sampling rectangle.
peak_density = (1/scale)/(np.sqrt(2*np.pi))
rejection_sampling(pdf, -5, 5, 50000, y_max=peak_density)

# Overlay the exact density on the scatter plot.
x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
data:image/s3,"s3://crabby-images/9dcd5/9dcd53f3e4c1b799b78c7381a077fd6a90633658" alt="../../_images/a6b93091091d02a4932222755806e11ca69246276e6b0dfb13ceaaaa8fbc766d.png"
Gamma distribution
from scipy.stats import gamma

# Gamma distribution with shape parameter 1.
rv = gamma(1)
pdf = rv.pdf

# 10000 candidates on [0, 5] with the default box height (y_max = 1).
rejection_sampling(pdf, lower_support=0, upper_support=5, samples=10000)

# Overlay the exact density on the scatter plot.
x = np.linspace(0, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
data:image/s3,"s3://crabby-images/662b2/662b2b29d4480e527e6646dd273da68fc98b5535" alt="../../_images/d564510dd803b817d577d8d5e0b9f6e42e8c45edeecb6cb4b0bb20b4244e0cfa.png"
# Gamma distribution with shape parameter 2.
rv = gamma(2)
pdf = rv.pdf

# 10000 candidates on [0, 5] with the default box height (y_max = 1).
# The part of the distribution beyond x = 5 is not sampled.
rejection_sampling(pdf, lower_support=0, upper_support=5, samples=10000)

# Overlay the exact density on the scatter plot.
x = np.linspace(0, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
data:image/s3,"s3://crabby-images/a693c/a693c1368a130faf7c0a06d8f7e953fef2bd3440" alt="../../_images/51bc60406ee8429175d8b2596191ee845f3a8aa4e48af06a67d2304b77ecd42a.png"
# Gamma distribution with shape parameter 10 (unit scale).
shape = 10
rv = gamma(shape)
pdf = rv.pdf

# This density has its mode at shape - 1 = 9 and its mean at 10, so a box
# ending at x = 10 would drop nearly half of the probability mass and bias
# the reported mean downwards. Sample on [0, 30] instead, where the omitted
# tail is negligible.
upper = 30

# The density is largest at the mode, so its value there is the smallest
# box height that still covers the whole curve (and wastes the fewest
# candidates).
rejection_sampling(pdf, 0, upper, 10000, y_max=pdf(shape - 1))

# Overlay the exact density on the scatter plot.
x = np.linspace(0, upper, 1000)
plt.plot(x, pdf(x),'k',lw=2);
data:image/s3,"s3://crabby-images/eb159/eb1593c9a5ba1730421d18a92cc1fc2fecea8600" alt="../../_images/9b88d8956aa9877dd56ee134d4a70025599c9fc28f47fbeb77126e0db4eea5d1.png"
Beta distribution
from scipy.stats import beta

# Beta distribution with a = 4.5, b = 5.
rv = beta(a=4.5, b=5)
pdf = rv.pdf

# The Beta density is zero outside [0, 1], so candidates are drawn only on
# that interval; a wider box would just add points that can never be
# accepted. y_max = 2.5 sits just above the peak of this density (about 2.4).
rejection_sampling(pdf, 0, 1, 10000, y_max=2.5)

# Overlay the exact density over the same interval.
x = np.linspace(0, 1, 1000)
plt.plot(x, pdf(x),'k',lw=2);
data:image/s3,"s3://crabby-images/c507e/c507e4e911a3cb5c8464eaa6a728acbaedc8f85c" alt="../../_images/9d8c863b63f131919d26cf8fae2fdb035c0d3ce893598e7fbc2fc309647518b6.png"