import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax
import numpy as np
import optax
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
# y_hat = x^t * theta for a single point

def pred_yhat(x, theta):
    """Return the linear prediction for a single feature vector.

    theta[0] is used as the intercept and theta[1:] as the weights that are
    dotted with x.
    """
    intercept, weights = theta[0], theta[1:]
    return jnp.dot(x, weights) + intercept
# Sanity check on one point: theta[0] + x . theta[1:] = -1 + 2*2 + 3*2 = 9.
x = jnp.array([2., 2.])
theta = jnp.array([-1., 2., 3.])
pred_yhat(x, theta)
DeviceArray(9., dtype=float32)

Using PRNG key

# JAX random functions take an explicit key; seed one with 0 and use it to
# draw a (100, 2) array of standard-normal samples.
key = random.PRNGKey(0)
X = random.normal(key, (100, 2))
print(X.shape)
(100, 2)

VMAP for auto-batching!

%timeit vmap(pred_yhat, in_axes=(0, None))(X, theta).block_until_ready()
1.31 ms ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit jnp.stack([pred_yhat(x, theta) for x in X]).block_until_ready()
98 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit X@theta[1:] + theta[0]
893 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

JIT for speedup

%timeit vmap(jit(pred_yhat), in_axes=(0, None))(X, theta).block_until_ready()
289 µs ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
def pred_y_hat_vector(X, theta):
    """Return pred_yhat evaluated on every row of X with one shared theta.

    vmap maps the jitted pred_yhat over axis 0 of X (in_axes entry 0) while
    theta is passed unbatched to every call (in_axes entry None).
    """
    batched_pred = vmap(jit(pred_yhat), in_axes=(0, None))
    return batched_pred(X, theta)
def cost(X, y, theta):
    """Return the 2-norm of the residual between y and the predictions.

    Predictions come from pred_y_hat_vector(X, theta). The result is
    jnp.linalg.norm(residual, 2), i.e. it is neither squared nor averaged.

    NOTE(review): presumably y has the same shape as the prediction vector.
    If y is 2-D the subtraction broadcasts and ord=2 becomes a matrix norm;
    confirm the shape of y at the call sites.
    """
    residual = y - pred_y_hat_vector(X, theta)
    return jnp.linalg.norm(residual, 2)
theta_gt = jnp.array([1., 4., 5.])
# Noise-free targets, one per row of X.
y_clean = vmap(jit(pred_yhat), in_axes=(0, None))(X, theta_gt)
# Draw the noise with exactly the shape of the targets: adding an (n, 1)
# noise array to an (n,) vector would broadcast y_gt to (n, n).
# Use a fresh subkey because `key` was already consumed to generate X and
# JAX keys should not be reused for independent draws.
noise_key = random.split(key)[1]
y_gt = y_clean + 0.2*random.normal(noise_key, y_clean.shape)
cost(X, y_gt, theta_gt)
DeviceArray(18.885782, dtype=float32)

Our initial estimates (theta) are not good

cost(X, y_gt, theta)
DeviceArray(377.96906, dtype=float32)
# Differentiate cost with respect to its third positional argument (theta).
# Because argnums is given as a list, grad_theta returns a tuple with one
# entry, which is why the result is indexed with [0] below.
grad_theta = grad(cost, argnums=[2])
lr = 0.001  # fixed step size
# Plain gradient descent for 50 steps, starting from the current theta.
for i in range(50):
    cost_val = cost(X, y_gt, theta)  # evaluated only so it can be printed
    print(i, cost_val)
    grad_theta_val = grad_theta(X, y_gt, theta)[0]
    theta = theta - lr*grad_theta_val  # step against the gradient
0 377.96906
1 365.99265
2 354.02716
3 342.07303
4 330.13083
5 318.20102
6 306.28424
7 294.38098
8 282.49225
9 270.61868
10 258.76102
11 246.92035
12 235.09787
13 223.29465
14 211.51213
15 199.75198
16 188.01602
17 176.30638
18 164.62575
19 152.97719
20 141.36456
21 129.7927
22 118.26772
23 106.79769
24 95.3937
25 84.071465
26 72.85466
27 61.781315
28 50.91854
29 40.401173
30 30.558014
31 22.45481
32 18.97314
33 18.814388
34 18.81421
35 18.814215
36 18.814215
37 18.814215
38 18.814215
39 18.814215
40 18.814215
41 18.814215
42 18.814215
43 18.814215
44 18.814215
45 18.814215
46 18.814215
47 18.814215
48 18.814215
49 18.814215
theta
DeviceArray([1.0210268, 3.881022 , 4.9754868], dtype=float32)
theta_gt
DeviceArray([1., 4., 5.], dtype=float32)

Using Optax instead of manually writing SGD

# Alternative optimizer, left disabled:
#optimizer = optax.adam(learning_rate=0.01)
# SGD with the same 0.001 step size as the manual loop's lr.
optimizer = optax.sgd(learning_rate=0.001)

# Reset theta to the same starting values used before the manual loop.
theta = jnp.array([-1., 2., 3.])
opt_state = optimizer.init(theta)
opt_state
(EmptyState(), EmptyState())
# Same 50-step descent as the manual loop, but the update is produced by
# optimizer.update and applied with optax.apply_updates.
for i in range(50):
    cost_val = cost(X, y_gt, theta)  # evaluated only so it can be printed
    print(i, cost_val)
    # grad_theta was built with argnums=[2], so it returns a 1-tuple.
    grad_theta_val = grad_theta(X, y_gt, theta)[0]
    updates, opt_state = optimizer.update(grad_theta_val, opt_state)
    theta = optax.apply_updates(theta, updates)
0 377.96906
1 365.99265
2 354.02716
3 342.07303
4 330.13083
5 318.20102
6 306.28424
7 294.38098
8 282.49225
9 270.61868
10 258.76102
11 246.92035
12 235.09787
13 223.29465
14 211.51213
15 199.75198
16 188.01602
17 176.30638
18 164.62575
19 152.97719
20 141.36456
21 129.7927
22 118.26772
23 106.79769
24 95.3937
25 84.071465
26 72.85466
27 61.781315
28 50.91854
29 40.401173
30 30.558014
31 22.45481
32 18.97314
33 18.814388
34 18.81421
35 18.814215
36 18.814215
37 18.814215
38 18.814215
39 18.814215
40 18.814215
41 18.814215
42 18.814215
43 18.814215
44 18.814215
45 18.814215
46 18.814215
47 18.814215
48 18.814215
49 18.814215

Is JAX quicker (even on CPU?!)

Gaussian Processes need Cholesky decompositions. Can we get a speedup using JAX instead of NumPy?

a = np.random.randn(1000, 1000)
b = a.T@a
%timeit np.linalg.cholesky(b)
6.82 ms ± 464 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
b = jnp.array(b)
%timeit jnp.linalg.cholesky(b)
2.19 ms ± 67.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)