Probabilstic PCA using PyTorch distributions
Contents
Probabilstic PCA using PyTorch distributions¶
Basic Imports¶
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import pandas as pd
dist =torch.distributions
sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format='retina'
Generative model for PPCA in PyTorch¶
data_dim = 2
latent_dim = 1
num_datapoints = 100
z = dist.Normal(
loc=torch.zeros([latent_dim, num_datapoints]),
scale=torch.ones([latent_dim, num_datapoints]),)
w = dist.Normal(
loc=torch.zeros([data_dim, latent_dim]),
scale=5.0 * torch.ones([data_dim, latent_dim]),
)
w_sample= w.sample()
z_sample = z.sample()
x = dist.Normal(loc = w_sample@z_sample, scale=1)
x_sample = x.sample([100])
plt.scatter(x_sample[:, 0], x_sample[:, 1], alpha=0.2, s=30)
<matplotlib.collections.PathCollection at 0x16051f1c0>
Generative model for PPCA in Pyro¶
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pyro
pyro.clear_param_store()
def ppca_model(data, latent_dim):
N, data_dim = data.shape
W = pyro.sample(
"W",
dist.Normal(
loc=torch.zeros([latent_dim, data_dim]),
scale=5.0 * torch.ones([latent_dim, data_dim]),
),
)
Z = pyro.sample(
"Z",
dist.Normal(
loc=torch.zeros([N, latent_dim]),
scale=torch.ones([N, latent_dim]),
),
)
mean = Z @ W
return pyro.sample("obs", pyro.distributions.Normal(mean, 1.0), obs=data)
pyro.render_model(
ppca_model, model_args=(torch.randn(150, 2), 1), render_distributions=True
)
ppca_model(x_sample[0], 3).shape
torch.Size([2, 100])
from pyro import poutine
with pyro.plate("samples", 10, dim=-3):
trace = poutine.trace(ppca_model).get_trace(x_sample[0], 1)
trace.nodes['W']['value'].squeeze()
torch.Size([10, 100])
data_dim = 3
latent_dim = 2
W = pyro.sample(
"W",
dist.Normal(
loc=torch.zeros([latent_dim, data_dim]),
scale=5.0 * torch.ones([latent_dim, data_dim]),
),
)
N = 150
Z = pyro.sample(
"Z",
dist.Normal(
loc=torch.zeros([N, latent_dim]),
scale=torch.ones([N, latent_dim]),
),
)
Z.shape, W.shape
(torch.Size([150, 2]), torch.Size([2, 3]))
(Z@W).shape
torch.Size([150, 3])