Linear Regression using Pyro

Basic Imports

import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import pandas as pd

# Short alias for the torch.distributions namespace.
t_dist =torch.distributions

# Reset seaborn/matplotlib styling to defaults, then apply the "talk" context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
# IPython magics (valid only inside a notebook/IPython session): draw figures
# inline and request the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'

Creating dataset

# Synthetic regression data: 200 evenly spaced inputs on [-5, 5].
x = torch.linspace(start=-5, end=5, steps=200)
# Noise-free targets on the line y = 3x + 4.
true_y = 4 + 3 * x
# Observed targets: the true line plus Gaussian noise with standard deviation 2.
observed_y = true_y + torch.randn(200) * 2

# Noisy observations as translucent points, the true line in black on top.
plt.scatter(x, observed_y, s=30, alpha=0.5)
plt.plot(x, true_y, color='k')

sns.despine()
../../_images/2022_02_17_pyro_linreg_4_0.png

Maximum Likelihood Estimation (MLE)

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

pyro.clear_param_store()


def mle_model(x, y=None):
    """Linear-Gaussian model whose intercept and slope are Pyro parameters.

    The mean of the likelihood is ``theta_0 + theta_1 * x`` and the
    observation noise scale is fixed at 1.

    Args:
        x: Inputs; ``len(x)`` gives the size of the ``"data"`` plate.
        y: Optional observations passed as ``obs`` to the ``"obs"`` site.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    # Both parameters are initialised from a standard normal draw.
    intercept = pyro.param("theta_0", torch.randn(1))
    slope = pyro.param("theta_1", torch.randn(1))
    likelihood = dist.Normal(intercept + slope * x, 1)

    with pyro.plate("data", len(x)):
        obs_value = pyro.sample("obs", likelihood, obs=y)
    return obs_value

#pyro.render_model(mle_model, model_args=(x, observed_y))
# Run the model once without observations; the pyro.param calls inside it
# create "theta_0" and "theta_1" in the param store from their random inits.
m = mle_model(x)
# Display the current values of the two parameters.
pyro.param("theta_0").item(), pyro.param("theta_1").item()
(-0.100727379322052, 1.5172470808029175)
# Scatter five independent draws from the model at the current parameter
# values, to show the spread of its predictions.
for _ in range(5):
    plt.scatter(x, mle_model(x).detach(), s=5, alpha=0.4)

# Overlay, once, the mean line implied by the current parameters.  The earlier
# `plt.plot(x, )` inside the loop plotted x against its own index on every
# iteration, which has nothing to do with the model output.
plt.plot(
    x,
    pyro.param("theta_1").item() * x + pyro.param("theta_0").item(),
    color='k',
)
../../_images/2022_02_17_pyro_linreg_8_0.png
def guide(x, y):
    """Empty guide for maximum-likelihood estimation with ``mle_model``.

    ``mle_model`` contains no latent sample sites: ``theta_0`` and ``theta_1``
    are ``pyro.param`` values and ``"obs"`` is observed.  The guide therefore
    has nothing to approximate and must not declare any sample sites.
    Sampling ``theta_0``/``theta_1`` here (as was done previously) creates
    guide-only latent variables, which Pyro reports as "Found non-auxiliary
    vars in guide but not model", and they do not correspond to the parameters
    being optimised.

    Args:
        x: Inputs (unused; accepted so the signature matches the model's).
        y: Observations (unused; accepted so the signature matches the model's).

    Returns:
        None.
    """
    # Intentionally empty: with no latent variables, SVI reduces to
    # maximising the likelihood with respect to the model's parameters.
    return None
pyro.render_model(guide, model_args=(x, observed_y))
../../_images/2022_02_17_pyro_linreg_10_0.svg
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
# Adam settings: learning rate and the two moment-decay coefficients.
adam_params = {"lr": 0.005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(mle_model, guide, optimizer, loss=Trace_ELBO())

n_steps = 5000
# do gradient steps
for step in range(n_steps):
    # Train on the observations generated together with `x`.  The previous
    # call passed `y`, which is not defined in this notebook; a stale value of
    # a different length caused the "not broadcastable ... [100] vs [200]"
    # error at site 'obs'.
    svi.step(x, observed_y)
    # Periodically print the current parameter estimates to watch convergence.
    if step % 110 == 0:
        print(pyro.param("theta_0").item(), pyro.param("theta_1").item())
/Users/nipun/miniforge3/lib/python3.9/site-packages/pyro/util.py:288: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:
{'theta_1', 'theta_0'}
  warnings.warn(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    229 try:
--> 230     log_p = site["fn"].log_prob(
    231         site["value"], *site["args"], **site["kwargs"]
    232     )
    233 except ValueError as e:

File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/normal.py:73, in Normal.log_prob(self, value)
     72 if self._validate_args:
---> 73     self._validate_sample(value)
     74 # compute the variance

File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/distribution.py:276, in Distribution._validate_sample(self, value)
    275     if i != 1 and j != 1 and i != j:
--> 276         raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    277                          format(actual_shape, expected_shape))
    278 try:

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([200]).

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Input In [117], in <module>
     10 # do gradient steps
     11 for step in range(n_steps):
---> 12     svi.step(x, y)
     13     if step%110==0:
     14         print(pyro.param("theta_0").item(), pyro.param("theta_1").item())

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
    180 else:
    181     for i in range(self.num_particles):
--> 182         yield self._get_trace(model, guide, args, kwargs)

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     72 guide_trace = prune_subsample_sites(guide_trace)
     73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
     76 guide_trace.compute_score_parts()
     77 if is_validation_enabled():

File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:236, in Trace.compute_log_prob(self, site_filter)
    234     _, exc_value, traceback = sys.exc_info()
    235     shapes = self.format_shapes(last_site=site["name"])
--> 236     raise ValueError(
    237         "Error while computing log_prob at site '{}':\n{}\n{}".format(
    238             name, exc_value, shapes
    239         )
    240     ).with_traceback(traceback) from e
    241 site["unscaled_log_prob"] = log_p
    242 log_p = scale_and_mask(log_p, site["scale"], site["mask"])

File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    228 if "log_prob" not in site:
    229     try:
--> 230         log_p = site["fn"].log_prob(
    231             site["value"], *site["args"], **site["kwargs"]
    232         )
    233     except ValueError as e:
    234         _, exc_value, traceback = sys.exc_info()

File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/normal.py:73, in Normal.log_prob(self, value)
     71 def log_prob(self, value):
     72     if self._validate_args:
---> 73         self._validate_sample(value)
     74     # compute the variance
     75     var = (self.scale ** 2)

File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/distribution.py:276, in Distribution._validate_sample(self, value)
    274 for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
    275     if i != 1 and j != 1 and i != j:
--> 276         raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
    277                          format(actual_shape, expected_shape))
    278 try:
    279     support = self.support

ValueError: Error while computing log_prob at site 'obs':
Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([200]).
Trace Shapes:      
 Param Sites:      
      theta_0     1
      theta_1     1
Sample Sites:      
     obs dist 200 |
        value 100 |
# Rebuild the input grid with exactly one point per observation.  Hard-coding
# 100 points here disagreed with `observed_y` / `true_y`, which are built from
# 200 inputs, and makes the plotting calls below fail on mismatched lengths.
x = torch.linspace(-5, 5, len(observed_y))
# Fitted line from the learned intercept (theta_0) and slope (theta_1).
predicted_y = pyro.param("theta_1").item()*x + pyro.param("theta_0").item()

# Observations as points, then the fitted line and the true line on top.
plt.scatter(x, observed_y, s=30, alpha=0.5)
plt.plot(x, predicted_y, color='k')
plt.plot(x, true_y, color='k')

sns.despine()
../../_images/2022_02_17_pyro_linreg_12_0.png
# Toy draw from a PPCA-style generative process: 1-D latents, 2-D data.
data_dim, latent_dim, num_datapoints = 2, 1, 100

# Standard-normal prior over the latents, one column per datapoint.
z = dist.Normal(
    loc=torch.zeros(latent_dim, num_datapoints),
    scale=torch.ones(latent_dim, num_datapoints),
)

# Zero-mean prior with scale 5 over the (data_dim x latent_dim) weights.
w = dist.Normal(
    loc=torch.zeros(data_dim, latent_dim),
    scale=torch.full((data_dim, latent_dim), 5.0),
)

# Draw the weights first, then the latents.
w_sample = w.sample()
z_sample = z.sample()

# Observation distribution centred on W z with scale 2; draw 100 samples from
# it and scatter the first data dimension against the second.
x = dist.Normal(loc=torch.matmul(w_sample, z_sample), scale=2)
x_sample = x.sample([100])
plt.scatter(x_sample[:, 0], x_sample[:, 1], alpha=0.2, s=30)
<matplotlib.collections.PathCollection at 0x135cdf700>
../../_images/2022_02_17_pyro_linreg_14_1.png

Generative model for PPCA in Pyro

import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pyro
pyro.clear_param_store()

def ppca_model(data, latent_dim):
    """PPCA generative model with one latent/observed site pair per data row.

    Args:
        data: 2-D tensor; each row is one observation.
        latent_dim: Dimensionality of each latent vector.

    The weight matrix ``W`` (``data_dim x latent_dim``, initialised to zeros)
    is a Pyro parameter.  Every row gets its own standard-normal latent
    ``z_i`` and is observed under a Normal centred on ``W @ z_i`` with
    scale 2.
    """
    _, data_dim = data.shape
    W = pyro.param("W", torch.zeros((data_dim, latent_dim)))

    for i, row in enumerate(data):
        latent_prior = dist.Normal(loc=torch.zeros(latent_dim), scale=1.)
        z_vec = pyro.sample("z_{}".format(i), latent_prior)
        pyro.sample(fr"\$x_{i}\$", dist.Normal(W @ z_vec, 2.), obs=row)


pyro.render_model(ppca_model, model_args=(torch.randn(150, 3), 1))
../../_images/2022_02_17_pyro_linreg_16_0.svg
pyro.clear_param_store()

def ppca_model2(data, latent_dim):
    """Vectorised PPCA generative model using a single latent matrix.

    Args:
        data: 2-D tensor of shape ``(N, data_dim)``.
        latent_dim: Dimensionality of each latent vector.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    N, data_dim = data.shape
    W = pyro.param("W", torch.zeros((data_dim, latent_dim)))
    # All N latent vectors are drawn at once as a (latent_dim, N) matrix.
    z_vec = pyro.sample("z", dist.Normal(loc=torch.zeros([latent_dim, N]), scale=1.))

    # (data_dim, latent_dim) @ (latent_dim, N) -> (data_dim, N); transpose so
    # the mean lines up with data's (N, data_dim) layout.
    x_mean = (W @ z_vec).t()
    # pyro.sample requires a distribution as its second argument.  Passing the
    # bare mean tensor raised "'ProvenanceTensor' object has no attribute
    # 'log_prob'".  The scale of 2 matches the per-row model `ppca_model`.
    return pyro.sample("obs", dist.Normal(x_mean, 2.), obs=data)
    
pyro.render_model(ppca_model2, model_args=(torch.randn(150, 3), 1))
torch.Size([3, 1]) torch.Size([1, 150]) torch.Size([150, 3]) torch.Size([150, 3])
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [110], in <module>
      9     print(W.shape, z_vec.shape, (W@z_vec).t().shape, data.shape)
     10     return pyro.sample("obs", (W@z_vec).t(), obs=data)
---> 12 pyro.render_model(ppca_model2, model_args=(torch.randn(150, 3), 1))

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/inspect.py:494, in render_model(model, model_args, model_kwargs, filename, render_distributions)
    471 def render_model(
    472     model: Callable,
    473     model_args: Optional[tuple] = None,
   (...)
    476     render_distributions: bool = False,
    477 ) -> "graphviz.Digraph":
    478     """
    479     Renders a model using `graphviz <https://graphviz.org>`_ .
    480 
   (...)
    492     :rtype: graphviz.Digraph
    493     """
--> 494     relations = get_model_relations(model, model_args, model_kwargs)
    495     graph_spec = generate_graph_specification(relations)
    496     graph = render_graph(graph_spec, render_distributions=render_distributions)

File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/inspect.py:287, in get_model_relations(model, model_args, model_kwargs)
    283 if site["type"] != "sample" or site_is_subsample(site):
    284     continue
    285 sample_sample[name] = [
    286     upstream
--> 287     for upstream in get_provenance(site["fn"].log_prob(site["value"]))
    288     if upstream != name
    289 ]
    290 sample_dist[name] = _get_dist_name(site["fn"])
    291 for frame in site["cond_indep_stack"]:

AttributeError: 'ProvenanceTensor' object has no attribute 'log_prob'
dist.Normal(loc = torch.tensor([0.]), scale = 1.).sample()
tensor([-1.2529])
pyro.clear_param_store()

# Observed and latent dimensionalities used by `ppca` below.
D, d = 2, 1

# Placeholder dataset: 100 all-zero observations of dimension D.
data = torch.zeros(100, D)


def ppca(data):
    """PPCA model with a loading matrix ``A`` and an offset ``mu`` as parameters.

    Args:
        data: Indexable collection of observations; ``len(data)`` sets the
            size of the ``"data"`` plate and ``data[i]`` is the i-th row.

    Iterating over the plate, each row gets a ``d``-dimensional standard
    normal latent and is observed under a unit-scale Normal centred on
    ``A @ z + mu``; both distributions treat their last dimension as the
    event dimension.
    """
    loading = pyro.param("A", torch.zeros((D, d)))
    offset = pyro.param("mu", torch.zeros(D))

    for idx in pyro.plate("data", len(data)):
        latent_prior = dist.Normal(torch.zeros(d), 1.0).to_event(1)
        latent = pyro.sample("latent_{}".format(idx), latent_prior)
        likelihood = dist.Normal(loading @ latent + offset, 1.0).to_event(1)
        pyro.sample("observed_{}".format(idx), likelihood, obs=data[idx])

pyro.render_model(ppca, model_kwargs={'data':data})
../../_images/2022_02_17_pyro_linreg_19_0.svg
ppca(data)
data.shape
torch.Size([100, 2])
# Draw N samples from a Normal with mean 5 and standard deviation 1.
N = 1000
x = np.random.normal(5, 1., N)
# Standard deviation of each prefix of length 2 .. N-1, keyed from 0.
o = {i: x[:stop].std() for i, stop in enumerate(range(2, N))}
pd.Series(o).plot()
# Reference line at the standard deviation used to generate the samples.
plt.axhline(y=1)
sns.despine()
../../_images/2022_02_17_pyro_linreg_24_0.png
# Apply the same prefix-wise standard deviation to the running estimates in `o`.
x2 = np.fromiter(o.values(), dtype=float)
# Prefixes of length 2 .. N-2, keyed from 0.
o2 = {i: x2[:i + 2].std() for i in range(N - 3)}
pd.Series(o2).plot()
<AxesSubplot:>
../../_images/2022_02_17_pyro_linreg_27_1.png