Rejection sampling#

Author: Nipun Batra

https://www.youtube.com/watch?v=kYWHfgkRc9s

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from scipy.stats import expon
import seaborn as sns
%matplotlib inline

import shutil

rc('font', size=16)
# usetex=True makes matplotlib run `latex` (and `dvipng` for PNG output) every
# time it draws text, and it raises if either tool is missing. Turn it on only
# when both are on PATH so the notebook still renders on machines without TeX.
rc('text', usetex=all(shutil.which(tool) is not None for tool in ('latex', 'dvipng')))

Exponential distribution#

Plotting the pdf of the standard exponential distribution (rate 1)

# Frozen standard exponential (scale 1) evaluated on a grid over [0, 10].
rv = expon()
x = np.linspace(0, 10, 1000)
density = rv.pdf(x)
plt.plot(x, density, label='pdf')
plt.xlabel('x')
plt.legend();
../../_images/98e541ffe682fd9b8dd35015c0230f7cdf59e6eb0904507df4f9f5e562683441.png

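Generating samples from a uniform distribution on [0, 10] and estimating their density with a KDE; with more samples the estimate becomes flatter.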

# Only 100 draws from U(0, 10): the kernel density estimate is still bumpy.
n_draws = 100
uni_samples = np.random.uniform(low=0, high=10, size=n_draws)

sns.kdeplot(uni_samples, label='pdf')
plt.xlabel('x')
plt.legend();
../../_images/7ce657c84922bbeafe74da700bd8980d24a4ef6030eeeaf1fffaacf8780b2074.png
# 100,000 draws from U(0, 10): the estimate is close to flat inside (0, 10);
# the sloped edges come from KDE smoothing across the support boundaries.
n_draws = 100_000
uni_samples = np.random.uniform(low=0, high=10, size=n_draws)

sns.kdeplot(uni_samples, label='pdf')
plt.xlabel('x')
plt.legend();
../../_images/df98c7a5442ab52eb8accda82cf2a3e1e4626779bf551c77fc32e1421a3fa37a.png

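Draw points uniformly from the box [0, 10] × [0, 1] and accept only those that fall in the area underneath the pdf. The x-coordinates of the accepted points are then distributed according to the pdf (truncated to [0, 10]). This works because the box height, 1, is at least the maximum of the exponential pdf, which is its value at x = 0.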

# Target density (black) drawn over the proposal box [0, 10) x [0, 1).
x = np.linspace(0, 10, 1000)
plt.plot(x, rv.pdf(x), 'k', lw=2)

n_proposals = 100000
samples_uniform_x = np.random.uniform(0, 10, n_proposals)  # proposal positions
samples_uniform_y = np.random.uniform(0, 1, n_proposals)   # uniform heights in the box

# A proposal survives only if its height lies under the density curve.
idx = samples_uniform_y < rv.pdf(samples_uniform_x)

for keep, colour, tag in ((idx, 'green', "Accepted"), (~idx, 'red', "Rejected")):
    plt.scatter(samples_uniform_x[keep], samples_uniform_y[keep],
                alpha=0.3, color=colour, s=0.1, label=tag)
plt.legend(bbox_to_anchor=(1, 1));
../../_images/2b6166f5932da54f4eb0bcd67acf418c27766c4fc5faf2f8091a7de67976bedc.png
# Normalise the histogram (density=True) so it is on the same scale as the
# exponential pdf the accepted samples should follow, and overlay that pdf.
# (The samples follow the pdf truncated to [0, 10); the lost mass e^-10 is negligible.)
grid = np.linspace(0, 10, 1000)
plt.hist(samples_uniform_x[idx], bins=100, density=True, label='Accepted samples')
plt.plot(grid, rv.pdf(grid), 'k', lw=2, label='pdf')
plt.legend();
../../_images/6fae03875ea965788b3317ae89e5a0c8c3599be521e92abb7bf6957d95609211.png

We can define a general function to do the rejection sampling.

def rejection_sampling(pdf, lower_support, upper_support, samples=1000, y_max=1,
                       *, rng=None, plot=True):
    """Sample from ``pdf`` by uniform ("box") rejection sampling.

    ``samples`` proposal points ``(x, y)`` are drawn uniformly from the box
    ``[lower_support, upper_support) x [0, y_max)``. A point is accepted when
    ``y < pdf(x)``. The accepted ``x`` values follow ``pdf`` truncated to
    ``[lower_support, upper_support)``, but only if ``pdf(x) <= y_max`` on
    that interval. Otherwise the top of the density is clipped and the
    samples are biased, and a ``RuntimeWarning`` is issued.

    Parameters
    ----------
    pdf : callable
        Vectorised density, called once on the array of proposal ``x`` values
        (e.g. ``scipy.stats.norm().pdf``).
    lower_support, upper_support : float
        Interval the proposal ``x`` values are drawn from.
    samples : int, default 1000
        Number of *proposals*; the number of accepted samples is smaller.
    y_max : float, default 1
        Height of the proposal box; should be >= the maximum of ``pdf`` on
        the interval.
    rng : numpy.random.Generator, optional
        Random source. ``None`` (default) uses the global ``np.random`` state.
    plot : bool, default True
        Scatter accepted (green) and rejected (red) proposals on the current
        matplotlib axes, titled with the mean of the accepted samples.

    Returns
    -------
    numpy.ndarray
        The accepted ``x`` values (possibly empty).
    """
    import warnings

    source = np.random if rng is None else rng
    samples_uniform_x = source.uniform(lower_support, upper_support, samples)
    samples_uniform_y = source.uniform(0, y_max, samples)

    pdfs = pdf(samples_uniform_x)
    if np.any(pdfs > y_max):
        # The box does not cover the density: acceptance probability is
        # min(pdf, y_max) / y_max, so the samples are biased towards the tails.
        warnings.warn(
            f"pdf exceeds y_max={y_max} (observed up to {np.max(pdfs):.4g}); "
            "accepted samples will not follow pdf -- increase y_max.",
            RuntimeWarning,
            stacklevel=2,
        )

    idx = samples_uniform_y < pdfs
    accepted = samples_uniform_x[idx]

    if plot:
        plt.scatter(samples_uniform_x[idx], samples_uniform_y[idx], alpha=0.6, color='green', s=0.1, label="Accepted")
        plt.scatter(samples_uniform_x[~idx], samples_uniform_y[~idx], alpha=0.6, color='red', s=0.1, label="Rejected")
        # Avoid numpy's "mean of empty slice" warning when nothing was accepted.
        mean = accepted.mean() if accepted.size else float('nan')
        plt.title(f'mean = {mean}')
        plt.legend()

    return accepted

Normal distribution#

from scipy.stats import norm

# Standard normal: peak density 1/sqrt(2*pi) ~ 0.40 fits under the default y_max=1.
scale = 1
rv = norm(loc=0, scale=scale)
pdf = rv.pdf
rejection_sampling(pdf, lower_support=-5, upper_support=5, samples=10000)

x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
../../_images/522b2a45f3251ab359af4b7593a471926c3295b9c11f40dbf954f9c11e59a575.png

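Let us try a smaller standard deviation. With σ = 0.1 the peak density is 1/(σ√(2π)) ≈ 3.99, which exceeds the default box height y_max = 1. The top of the curve is therefore cut off, and the accepted samples no longer follow the pdf.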

from scipy.stats import norm

# Narrow normal (sd 0.1): peak density 1/(0.1*sqrt(2*pi)) ~ 3.99 is far above
# the default box height y_max=1, so the top of the curve is clipped and the
# accepted samples do not follow pdf.
scale = 0.1
rv = norm(loc=0, scale=scale)
pdf = rv.pdf
rejection_sampling(pdf, lower_support=-5, upper_support=5, samples=10000)

x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
../../_images/01309329398e6cfe89478f3feb6dd54d918f8e652add8cd44c859af7ea53b268.png

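In this case we need to make the sampling box taller: its height y_max must be at least the peak of the pdf, 1/(σ√(2π)). Most of the taller box lies above the curve, so we also draw more proposals to get a comparable number of accepted samples. (The next cell also shifts the mean to 1.)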

# Raise the box height to the normal's peak density 1/(scale*sqrt(2*pi)) so the
# whole curve fits; more proposals compensate for the lower acceptance rate.
# NOTE(review): loc=1 here, whereas the previous cell used loc=0 — confirm intended.
scale = 0.1
rv = norm(loc=1, scale=scale)
pdf = rv.pdf
peak_density = (1/scale)/(np.sqrt(2*np.pi))
rejection_sampling(pdf, lower_support=-5, upper_support=5, samples=50000, y_max=peak_density)

x = np.linspace(-5, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
../../_images/a6b93091091d02a4932222755806e11ca69246276e6b0dfb13ceaaaa8fbc766d.png

Gamma distribution#

from scipy.stats import gamma

# Gamma with shape 1 is the standard exponential: its density peaks at 1 at
# x = 0, exactly the default box height y_max=1. Stopping at 5 drops e^-5 ~ 0.7% of the mass.
rv = gamma(1)
pdf = rv.pdf
rejection_sampling(pdf, lower_support=0, upper_support=5, samples=10000)

x = np.linspace(0, 5, 1000)
plt.plot(x, pdf(x), 'k', linewidth=2);
../../_images/d564510dd803b817d577d8d5e0b9f6e42e8c45edeecb6cb4b0bb20b4244e0cfa.png
# Gamma(2): mean 2, peak density 1/e ~ 0.37 at x = 1. Proposing only on [0, 5]
# dropped ~4% of the mass (P(X > 5) = 6e^-5) and pulled the titled sample mean
# down to ~1.8; [0, 12] loses < 1e-4. y_max=0.4 still covers the peak but wastes
# far fewer proposals than the default y_max=1.
rv = gamma(2)
pdf = rv.pdf
rejection_sampling(pdf, 0, 12, 10000, y_max=0.4)
x = np.linspace(0, 12, 1000)
plt.plot(x, pdf(x),'k',lw=2);
../../_images/51bc60406ee8429175d8b2596191ee845f3a8aa4e48af06a67d2304b77ecd42a.png
# Gamma(10): mean 10, sd ~3.2, peak density ~0.13 at the mode x = 9.
# Proposing only on [0, 10] discarded ~46% of the distribution (most of
# everything above the mean), so the titled sample mean came out near 7.7.
# [0, 25] keeps all but ~0.02% of the mass; y_max=0.15 covers the peak.
rv = gamma(10)
pdf = rv.pdf
rejection_sampling(pdf, 0, 25, 10000, y_max=0.15)
x = np.linspace(0, 25, 1000)
plt.plot(x, pdf(x),'k',lw=2);
../../_images/9b88d8956aa9877dd56ee134d4a70025599c9fc28f47fbeb77126e0db4eea5d1.png

Beta distribution#

from scipy.stats import beta

# Beta(4.5, 5) is supported on [0, 1], so propose (and plot) only there;
# a [0, 5] box wastes ~80% of proposals where the pdf is zero.
# Its peak density (at the mode 3.5/7.5 ~ 0.47) is ~2.4, so y_max=2.5 covers it.
rv = beta(a=4.5, b=5)
pdf = rv.pdf
rejection_sampling(pdf, 0, 1, 10000, y_max=2.5)
x = np.linspace(0, 1, 1000)
plt.plot(x, pdf(x),'k',lw=2);
../../_images/9d8c863b63f131919d26cf8fae2fdb035c0d3ce893598e7fbc2fc309647518b6.png