# Autograd in JAX and PyTorch

#### Basic Imports

In [1]:
import torch
from jax import grad
import jax.numpy as jnp

#### Creating scalar variables in PyTorch

In [2]:
# Scalar leaf tensors that track gradients. `torch.autograd.Variable` is a
# deprecated wrapper; `requires_grad=True` on `torch.tensor` creates the leaf directly.
x_torch = torch.tensor(1., requires_grad=True)
y_torch = torch.tensor(1., requires_grad=True)

#### Creating scalar variables in JAX

In [3]:
# Plain scalar JAX arrays; no gradient flag is set here because the
# derivative is requested later through `grad`.
x_jax, y_jax = jnp.array(1.0), jnp.array(1.0)



#### Defining a loss on scalar inputs

In [4]:
def loss(x, y):
    """Return the sum of squares ``x*x + y*y`` of the two inputs."""
    x_squared = x * x
    y_squared = y * y
    return x_squared + y_squared

#### Computing the loss on PyTorch input

In [5]:
# Evaluate the scalar loss on the PyTorch leaves and display the result.
l_torch = loss(x_torch, y_torch)
l_torch

tensor(2., grad_fn=<AddBackward0>)

#### Computing the loss on JAX input

In [6]:
# The same `loss` function applied to the JAX arrays; the value is stored but not displayed.
l_jax = loss(x_jax, y_jax)

#### Computing the gradient on PyTorch input

In [7]:
# Backpropagate from the scalar loss, then read the gradient stored on each input's `.grad`.
l_torch.backward()
x_torch.grad, y_torch.grad

(tensor(2.), tensor(2.))

#### Computing the gradient on JAX input

In [8]:
# Differentiate `loss` with respect to both positional arguments (indices 0 and 1);
# the resulting function returns one gradient per selected argument, as a tuple.
grad_loss = grad(loss, argnums=(0, 1))
grad_loss(x_jax, y_jax)

(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(2., dtype=float32, weak_type=True))

#### Repeating the same procedure for both libraries, this time with a loss that takes a vector input

In [9]:
def loss(theta):
    """Return ``theta^T theta``, the sum of squared entries of a vector.

    Note: this rebinds the name ``loss`` and therefore replaces the
    two-argument scalar version defined earlier in the notebook.

    Args:
        theta: A 1-D vector (PyTorch tensor or JAX array). Inputs of any
            other rank are evaluated as ``theta.T @ theta``, unchanged.

    Returns:
        For a 1-D input, the scalar dot product of ``theta`` with itself.
    """
    if theta.ndim == 1:
        # `.T` is a no-op on a 1-D vector, and PyTorch deprecates `.T` on
        # tensors that are not 2-D, so take the dot product directly.
        return theta @ theta
    return theta.T @ theta

In [10]:
# Leaf vector created directly with gradient tracking; the deprecated
# `torch.autograd.Variable` wrapper is not needed.
theta_torch = torch.tensor([1., 1.], requires_grad=True)

In [11]:
# Display the tensor; its repr shows that gradient tracking is enabled.
theta_torch

tensor([1., 1.], requires_grad=True)

In [12]:
# Evaluate the vector loss on the PyTorch tensor and display the resulting scalar.
l = loss(theta_torch)
l

tensor(2., grad_fn=<DotBackward0>)

In [13]:
# Backpropagate from the scalar loss, then read the gradient stored on `theta_torch`.
l.backward()
theta_torch.grad

tensor([2., 2.])

In [14]:
# JAX counterpart of `theta_torch`: the same two-element vector of ones.
theta_jax = jnp.array([1.0, 1.0])

In [15]:
# The same `loss` function evaluated on the JAX array.
loss(theta_jax)

DeviceArray(2., dtype=float32)

In [16]:
# A sequence-valued `argnums` makes the gradient function return a tuple,
# here with a single entry for argument 0.
grad_loss = grad(loss, argnums=(0,))

In [17]:
# Gradient of the vector loss at `theta_jax`, returned as a 1-tuple.
grad_loss(theta_jax)

(DeviceArray([2., 2.], dtype=float32),)