Linear Regression in TF Probability using JointDistributionCoroutineAutoBatched
Contents
Linear Regression in TF Probability using JointDistributionCoroutineAutoBatched¶
toc: true
badges: true
comments: true
author: Nipun Batra
categories: [ML, TFP, TF]
Basic Imports¶
from silence_tensorflow import silence_tensorflow
silence_tensorflow()
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import functools
import seaborn as sns
import tensorflow_probability as tfp
import pandas as pd
tfd = tfp.distributions
tfl = tfp.layers
tfb = tfp.bijectors
sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format='retina'
np.random.seed(0)
tf.random.set_seed(0)
def lr(x, stddv_datapoints):
num_datapoints, data_dim = x.shape
b = yield tfd.Normal(
loc=0.0,
scale=2.0,
name="b",
)
w = yield tfd.Normal(
loc=tf.zeros([data_dim]), scale=2.0 * tf.ones([data_dim]), name="w"
)
y = yield tfd.Normal(
loc=tf.linalg.matvec(x, w) + b, scale=stddv_datapoints, name="y"
)
x = tf.linspace(-5.0, 5.0, 100)
x = tf.expand_dims(x, 1)
stddv_datapoints = 1
concrete_lr_model = functools.partial(lr, x=x, stddv_datapoints=stddv_datapoints)
model = tfd.JointDistributionCoroutineAutoBatched(concrete_lr_model)
model
<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
b=[],
w=[1],
y=[100]
) dtype=StructTuple(
b=float32,
w=float32,
y=float32
)>
actual_b, actual_w, y_train = model.sample()
plt.scatter(x, y_train, s=30, alpha=0.6)
plt.plot(x, tf.linalg.matvec(x, actual_w) + actual_b, color="k")
sns.despine()
trace_fn = lambda traceable_quantities: {
"loss": traceable_quantities.loss,
"w": w,
"b": b,
}
data_dim = 1
w = tf.Variable(tf.zeros_like(actual_w))
b = tf.Variable(tf.zeros_like(actual_b))
target_log_prob_fn = lambda w, b: model.log_prob((b, w, y_train))
target_log_prob_fn
<function __main__.<lambda>(w, b)>
trace = tfp.math.minimize(
lambda: -target_log_prob_fn(w, b),
optimizer=tf.optimizers.Adam(learning_rate=0.05),
trace_fn=trace_fn,
num_steps=200,
)
w, b, actual_w, actual_b
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1811483], dtype=float32)>,
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0149531>,
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.1337605], dtype=float32)>,
<tf.Tensor: shape=(), dtype=float32, numpy=3.0221252>)
plt.plot(trace["w"], label="w")
plt.plot(trace["b"], label="b")
plt.legend()
sns.despine()
qw_mean = tf.Variable(tf.random.normal([data_dim]))
qb_mean = tf.Variable(tf.random.normal([1]))
qw_stddv = tfp.util.TransformedVariable(
1e-4 * tf.ones([data_dim]), bijector=tfb.Softplus()
)
qb_stddv = tfp.util.TransformedVariable(1e-4 * tf.ones([1]), bijector=tfb.Softplus())
def factored_normal_variational_model():
qw = yield tfd.Normal(loc=qw_mean, scale=qw_stddv, name="qw")
qb = yield tfd.Normal(loc=qb_mean, scale=qb_stddv, name="qb")
surrogate_posterior = tfd.JointDistributionCoroutineAutoBatched(
factored_normal_variational_model
)
losses = tfp.vi.fit_surrogate_posterior(
target_log_prob_fn,
surrogate_posterior=surrogate_posterior,
optimizer=tf.optimizers.Adam(learning_rate=0.05),
num_steps=200,
)
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
warnings.warn(
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
warnings.warn(
qw_mean, qw_stddv, qb_mean, qb_stddv
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1905935], dtype=float32)>,
<TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.04352505], dtype=float32)>,
<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([3.0260112], dtype=float32)>,
<TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.09258726], dtype=float32)>)
s_qw, s_qb = surrogate_posterior.sample(500)
ys = tf.linalg.matvec(x, s_qw) + s_qb
x.shape, ys.shape
(TensorShape([100, 1]), TensorShape([500, 100]))
plt.plot(x, ys.numpy().T, color='k', alpha=0.05);
plt.scatter(x, y_train, s=30, alpha=0.6)
sns.despine()
TODO
How to replace x in
lr
function from x_train to x_test?
References