Linear Regression in TF Probability using JointDistributionCoroutineAutoBatched

  • toc: true

  • badges: true

  • comments: true

  • author: Nipun Batra

  • categories: [ML, TFP, TF]

Basic Imports

from silence_tensorflow import silence_tensorflow
silence_tensorflow()

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import functools
import seaborn as sns
import tensorflow_probability as tfp
import pandas as pd

# Short aliases for the TFP submodules (distributions, layers, bijectors).
tfd = tfp.distributions
tfl = tfp.layers
tfb = tfp.bijectors

# Reset seaborn to its default style first, then apply the "talk" plotting context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Fix the global NumPy and TensorFlow seeds before any sampling is done.
np.random.seed(0)
tf.random.set_seed(0)
def lr(x, stddv_datapoints):
    """Coroutine describing a Bayesian linear-regression joint distribution.

    Yields three distributions in order: the prior over the scalar intercept
    ``b``, the prior over the weight vector ``w`` (one entry per column of
    ``x``), and the likelihood of the observations ``y`` given ``x``, ``w``
    and ``b``.

    Args:
        x: Rank-2 input matrix; its second dimension sets the length of ``w``.
        stddv_datapoints: Scale of the Normal likelihood over ``y``.
    """
    # Unpacking into two names requires x to be rank 2; only the feature
    # count is needed below.
    _, n_features = x.shape

    # Intercept prior: Normal(0, 2).
    b = yield tfd.Normal(loc=0.0, scale=2.0, name="b")

    # Weight prior: Normal(0, 2) for every feature.
    w_loc = tf.zeros([n_features])
    w_scale = 2.0 * tf.ones([n_features])
    w = yield tfd.Normal(loc=w_loc, scale=w_scale, name="w")

    # Likelihood: observations scatter around the line x @ w + b.
    y_mean = tf.linalg.matvec(x, w) + b
    yield tfd.Normal(loc=y_mean, scale=stddv_datapoints, name="y")
# Inputs: 100 evenly spaced points on [-5, 5], shaped as a (100, 1) column.
x = tf.expand_dims(tf.linspace(-5.0, 5.0, 100), axis=1)
stddv_datapoints = 1  # scale of the observation noise handed to `lr`

# Bind the inputs and noise scale so the coroutine takes no arguments, then
# wrap it as an auto-batched joint distribution over (b, w, y).
concrete_lr_model = functools.partial(
    lr,
    x=x,
    stddv_datapoints=stddv_datapoints,
)
model = tfd.JointDistributionCoroutineAutoBatched(concrete_lr_model)
model
<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  b=[],
  w=[1],
  y=[100]
) dtype=StructTuple(
  b=float32,
  w=float32,
  y=float32
)>
# Draw one joint sample from the model: an intercept, a weight vector, and the
# 100 noisy observations that are used as training targets below.
actual_b, actual_w, y_train = model.sample()
plt.scatter(x, y_train, s=30, alpha=0.6)
# Overlay the noise-free line implied by the sampled parameters.
plt.plot(x, tf.linalg.matvec(x, actual_w) + actual_b, color="k")
sns.despine()
../../_images/089b6c2bf0c78a2c8def4eac8ab32de146b73a17b16b080f6155eae9e1411a3b.png
def trace_fn(traceable_quantities):
    """Return the quantities to record while optimizing: loss, ``w`` and ``b``.

    ``w`` and ``b`` are looked up as module-level names when the function is
    called, so they only need to exist by the time optimization starts.
    """
    return {
        "loss": traceable_quantities.loss,
        "w": w,
        "b": b,
    }
data_dim = 1  # number of input features; x has a single column
# Trainable point estimates, initialized at zero with the same shapes as the
# sampled "true" parameters.
w = tf.Variable(tf.zeros_like(actual_w))

b = tf.Variable(tf.zeros_like(actual_b))

# Log joint density of the model at (b, w) with y fixed to the observed data.
# The model orders its variables as (b, w, y), so the (w, b) arguments are
# swapped before being handed to `log_prob`.
target_log_prob_fn = lambda w, b: model.log_prob((b, w, y_train))
target_log_prob_fn
<function __main__.<lambda>(w, b)>
# Point estimate of (w, b): minimize the negative log joint density with the
# observations held fixed, tracing the loss and both parameters along the way.
def negative_log_joint():
    return -target_log_prob_fn(w, b)


trace = tfp.math.minimize(
    negative_log_joint,
    num_steps=200,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    trace_fn=trace_fn,
)
w, b, actual_w, actual_b
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1811483], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0149531>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.1337605], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=3.0221252>)
# Show how each traced parameter evolved over the optimization steps.
for param_name in ("w", "b"):
    plt.plot(trace[param_name], label=param_name)
plt.legend()
sns.despine()
../../_images/afd695cc39c1389f14df9cfb560184ca784d8fda7faeb08cd672465b16c3320b.png
# Variational parameters of the factored Normal surrogate: a mean and a
# standard deviation for each of w and b. Means start at random normal draws.
qw_mean = tf.Variable(tf.random.normal([data_dim]))
# NOTE(review): qb_mean has shape [1] while the model's `b` has event shape []
# — confirm that the broadcasting this relies on is intended.
qb_mean = tf.Variable(tf.random.normal([1]))
# Standard deviations start at 1e-4 and are stored through a Softplus
# bijector (presumably so they stay positive during optimization).
qw_stddv = tfp.util.TransformedVariable(
    1e-4 * tf.ones([data_dim]), bijector=tfb.Softplus()
)
qb_stddv = tfp.util.TransformedVariable(1e-4 * tf.ones([1]), bijector=tfb.Softplus())
def factored_normal_variational_model():
    """Coroutine for the surrogate posterior: separate Normals for w and b.

    Yields ``qw`` first and ``qb`` second, each parameterized by the
    module-level mean and standard-deviation variables.
    """
    yield tfd.Normal(loc=qw_mean, scale=qw_stddv, name="qw")
    yield tfd.Normal(loc=qb_mean, scale=qb_stddv, name="qb")


# Turn the coroutine into a joint distribution over (qw, qb).
surrogate_posterior = tfd.JointDistributionCoroutineAutoBatched(
    factored_normal_variational_model
)

# Fit the surrogate's variational parameters against the model's log joint
# density (with y fixed to the training data).
vi_optimizer = tf.optimizers.Adam(learning_rate=0.05)
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    optimizer=vi_optimizer,
    num_steps=200,
)
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
  warnings.warn(
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
  warnings.warn(
qw_mean, qw_stddv, qb_mean, qb_stddv
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1905935], dtype=float32)>,
 <TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.04352505], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([3.0260112], dtype=float32)>,
 <TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.09258726], dtype=float32)>)
# Draw 500 (w, b) pairs from the fitted surrogate posterior.
s_qw, s_qb = surrogate_posterior.sample(500)
# One regression line per posterior draw: matvec broadcasts x against the
# batch of weights, and the intercepts broadcast across the 100 inputs.
ys = tf.linalg.matvec(x, s_qw) + s_qb
x.shape, ys.shape
(TensorShape([100, 1]), TensorShape([500, 100]))
# Overlay the sampled posterior regression lines on the training data.
posterior_lines = ys.numpy().T  # transposed so each column is one sampled line
plt.plot(x, posterior_lines, color="k", alpha=0.05);
plt.scatter(x, y_train, s=30, alpha=0.6)
sns.despine()
../../_images/29df4581b85e533acbeb5cf5220e653ec34dd879547087f2ce2b48aed77dc66b.png

TODO

  1. How can the x bound into the lr function be switched from x_train to x_test?

References