ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes
Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.
Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture (see References for details and nuances of this correspondence).
Neural Tangents allows you to construct a neural network model with the usual building blocks like convolutions, pooling, residual connections, nonlinearities, etc., and obtain not only the finite model, but also the kernel function of the respective GP.
The library is written in Python using JAX and leverages XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
Neural Tangents is a work in progress. We happily welcome contributions!
An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.
To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running
pip install jax jaxlib --upgrade
Once JAX is installed, install Neural Tangents by running
pip install neural-tangents
or, to use the bleeding-edge version from GitHub source,
git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .
You can now run the examples (using tensorflow_datasets) and tests by calling:
pip install tensorflow tensorflow-datasets more-itertools --upgrade
python examples/infinite_fcn.py
python examples/weight_space.py
python examples/function_space.py
set -e; for f in tests/*.py; do python $f; done
See this Colab for a detailed tutorial. Below is a very quick introduction.
Our library closely follows JAX's API for specifying neural networks, stax. In stax, a network is defined by a pair of functions (init_fn, apply_fn) initializing the trainable parameters and computing the outputs of the network, respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.
from jax import random
from jax.experimental import stax
init_fn, apply_fn = stax.serial(
stax.Dense(512), stax.Relu,
stax.Dense(512), stax.Relu,
stax.Dense(1)
)
key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)
y = apply_fn(params, x) # (10, 1) np.ndarray outputs of the neural network
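The (init_fn, apply_fn) pattern is simple enough to re-create in a few lines, which makes it clear what serial does: it threads the output shape through each layer's init_fn and the activations through each apply_fn. The sketch below is a hypothetical plain-NumPy imitation for illustration only; it is not jax.experimental.stax, and the initialization scale (1/sqrt(fan_in)) is an assumption.

```python
import numpy as np

# Hypothetical NumPy imitation of the stax (init_fn, apply_fn) pattern.
# Not the JAX implementation; names only mirror it for readability.

def Dense(out_dim):
    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        W = rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, out_dim))
        b = np.zeros(out_dim)
        # Like stax, return (output_shape, params).
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fn(params, x):
        W, b = params
        return x @ W + b

    return init_fn, apply_fn


def _relu_init(rng, input_shape):
    return input_shape, ()  # No trainable parameters.


def _relu_apply(params, x):
    return np.maximum(x, 0.0)


Relu = (_relu_init, _relu_apply)  # Used without a call, as in jax stax.


def serial(*layers):
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_shape):
        params = []
        for init in init_fns:
            input_shape, p = init(rng, input_shape)
            params.append(p)
        return input_shape, params

    def apply_fn(params, x):
        for apply, p in zip(apply_fns, params):
            x = apply(p, x)
        return x

    return init_fn, apply_fn


init_fn, apply_fn = serial(Dense(512), Relu, Dense(512), Relu, Dense(1))
rng = np.random.default_rng(1)
x = rng.normal(size=(10, 100))
out_shape, params = init_fn(rng, x.shape)
y = apply_fn(params, x)  # (10, 1) outputs
```

Because each layer is just a pair of closures, adding a third function per layer (as Neural Tangents does with kernel_fn) composes the same way.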
Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.
from jax import random
from neural_tangents import stax
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(512), stax.Relu(),
stax.Dense(512), stax.Relu(),
stax.Dense(1)
)
key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))
kernel = kernel_fn(x1, x2, 'nngp')
Note that kernel_fn can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels, respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1-5]. The NTK corresponds to the (continuous) gradient descent trained infinite network [10]. In the above example, we compute the NNGP kernel, but we could compute the NTK or both:
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray
# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp # True
both.ntk == ntk # True
# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
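For a fully-connected ReLU network like the one above, both kernels have a well-known layer-by-layer closed form based on the arc-cosine kernel. The plain-NumPy sketch below implements that recursion to show conceptually what kernel_fn(x1, x2, ('nngp', 'ntk')) returns for such an architecture; it is not Neural Tangents' implementation. The helper names, the use of the NTK parameterization, and the W_std/b_std arguments (assumed to play the role of the positional scales in stax.Dense(num_classes, 1., 0.) used later in this README) are our assumptions.

```python
import numpy as np


def relu_expectations(k11, k22, k12):
    """For (u, v) ~ N(0, [[k11, k12], [k12, k22]]), return the closed forms of
    E[relu(u) relu(v)] and E[relu'(u) relu'(v)] (arc-cosine kernels)."""
    norms = np.sqrt(k11 * k22)
    cos_theta = np.clip(k12 / norms, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    e_relu = norms * (np.sin(theta) + (np.pi - theta) * cos_theta) / (2 * np.pi)
    e_drelu = (np.pi - theta) / (2 * np.pi)
    return e_relu, e_drelu


def dense_relu_kernels(x1, x2, depth, W_std=1.0, b_std=0.0):
    """Infinite-width NNGP and NTK of `depth` x [Dense, Relu] followed by a
    final Dense, assuming NTK parameterization and nonzero inputs.
    Returns two arrays of shape (len(x1), len(x2))."""
    d = x1.shape[1]
    # First Dense layer: covariance of the pre-activations.
    nngp = W_std**2 * (x1 @ x2.T) / d + b_std**2
    var1 = W_std**2 * np.sum(x1**2, axis=1) / d + b_std**2
    var2 = W_std**2 * np.sum(x2**2, axis=1) / d + b_std**2
    ntk = nngp.copy()  # For the first layer, the NTK equals the NNGP.
    for _ in range(depth):
        e_relu, e_drelu = relu_expectations(var1[:, None], var2[None, :], nngp)
        # Relu followed by Dense: new NNGP, plus the NTK contribution
        # propagated from the layers below.
        nngp = W_std**2 * e_relu + b_std**2
        ntk = nngp + W_std**2 * e_drelu * ntk
        # On the diagonal theta = 0, so E[relu(u)^2] = var / 2.
        var1 = W_std**2 * var1 / 2 + b_std**2
        var2 = W_std**2 * var2 / 2 + b_std**2
    return nngp, ntk


rng = np.random.default_rng(1)
x1, x2 = rng.normal(size=(10, 100)), rng.normal(size=(20, 100))
# Two hidden Dense+Relu blocks, as in serial(Dense, Relu, Dense, Relu, Dense).
nngp, ntk = dense_relu_kernels(x1, x2, depth=2)  # each of shape (10, 20)
```

Note that the hidden width (512 in the example) does not appear anywhere: in the infinite-width limit the kernel depends only on the architecture and the weight/bias scales.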
Additionally, if no third argument is specified, kernel_fn will return a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:
kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)
Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:
import neural_tangents as nt
x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1)) # training targets
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train)
y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)
# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp # True
both.ntk == y_test_ntk # True
# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
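At infinite training time, the mean prediction is the classical GP formula K(x_test, x_train) K(x_train, x_train)^-1 y_train. With the NNGP kernel this is the Bayesian posterior mean; with the NTK substituted, the same formula gives the mean output of the infinite ensemble trained to convergence on MSE. The NumPy sketch below computes only this mean, not the predictive covariance (which differs between the two cases). The rbf stand-in kernel and the diag_reg jitter argument are hypothetical, chosen just to make the example self-contained.

```python
import numpy as np


def gp_posterior_mean(k_train_train, k_test_train, y_train, diag_reg=0.0):
    """Mean prediction K_test,train (K_train,train + diag_reg * I)^-1 y_train."""
    n = k_train_train.shape[0]
    # Solve the linear system rather than forming an explicit inverse.
    alpha = np.linalg.solve(k_train_train + diag_reg * np.eye(n), y_train)
    return k_test_train @ alpha


def rbf(a, b):
    """Stand-in kernel for the demo; an NNGP or NTK matrix would go here."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists)


rng = np.random.default_rng(0)
x_train, x_test = rng.normal(size=(10, 3)), rng.normal(size=(20, 3))
y_train = rng.uniform(size=(10, 1))
y_test = gp_posterior_mean(rbf(x_train, x_train), rbf(x_test, x_train), y_train)
# y_test has shape (20, 1)
```

With diag_reg=0, evaluating at the training inputs reproduces y_train exactly (the infinitely trained network interpolates); a small positive diag_reg trades that for numerical stability.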
We can define a more complex, (infinitely) Wide Residual Network [14] using the same nt.stax building blocks:
from neural_tangents import stax
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())
def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)
def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))
init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
The neural_tangents (nt) package contains the following modules and functions:
stax - primitives to construct neural networks like Conv, Relu, serial, parallel, etc.
predict - predictions with infinite networks:
  predict.gradient_descent_mse - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.
  predict.gradient_descent - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.
  predict.gradient_descent_mse_ensemble - inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only [11], since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.
  predict.gp_inference - exact closed-form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting a precomputed kernel matrix k_train_train instead of kernel_fn and x_train).
monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions.
Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
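The idea behind a Monte Carlo kernel estimate is to average the empirical output covariance f(x1) f(x2)^T over many random initializations of a finite network. The NumPy sketch below does this for a one-hidden-layer ReLU network; it is not nt.monte_carlo_kernel_fn, and the architecture, N(0, 1) weights with 1/sqrt(fan_in) scaling, the absence of biases, and all argument names are assumptions for illustration.

```python
import numpy as np


def mc_nngp_one_hidden(x1, x2, width=256, out_dim=8, n_samples=500, seed=0):
    """Monte Carlo NNGP estimate for Dense(width) -> Relu -> Dense(out_dim),
    with N(0, 1) weights, no biases and 1/sqrt(fan_in) scaling."""
    rng = np.random.default_rng(seed)
    d = x1.shape[1]
    total = np.zeros((x1.shape[0], x2.shape[0]))
    for _ in range(n_samples):
        # Draw one random network; the same weights are applied to x1 and x2.
        W1 = rng.normal(size=(d, width))
        W2 = rng.normal(size=(width, out_dim))
        f1 = np.maximum(x1 @ W1 / np.sqrt(d), 0.0) @ W2 / np.sqrt(width)
        f2 = np.maximum(x2 @ W1 / np.sqrt(d), 0.0) @ W2 / np.sqrt(width)
        # Empirical output covariance, averaged over the output units.
        total += f1 @ f2.T / out_dim
    return total / n_samples


x = np.ones((1, 4))
estimate = mc_nngp_one_hidden(x, x)
# For this input the pre-activations are N(0, 1), so the exact NNGP value is
# E[relu(h)^2] = 0.5; the estimate should be close to it.
```

The estimate only needs forward passes, which is why this approach applies to architectures without a closed-form kernel, at the cost of sampling noise that shrinks with more samples and wider networks.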
nt.stax vs jax.experimental.stax
We remark the following differences between our library and the JAX one.
nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.experimental.stax.Relu.
… parameterization keyword argument (see [15]).
nt.stax and jax.experimental.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
For CNNs with pooling, our CPU and TPU performance is suboptimal due to low core utilization (10-20%, looks like an XLA:CPU issue) and excessive padding, respectively. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.
The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse, allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
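Under continuous gradient descent on MSE with a fixed NTK Theta, the train outputs obey df/dt = -lr * Theta (f - y), which integrates to f_train(t) = y + exp(-lr Theta t)(f_train(0) - y), and the test outputs follow by integrating Theta_test,train times the same residual. The NumPy sketch below evaluates this closed form via an eigendecomposition. It assumes the loss 0.5 * ||f_train - y_train||^2 (libraries may normalize the loss or learning rate differently) and a positive-definite Theta; the function name is hypothetical.

```python
import numpy as np


def gradient_flow_mse(ntk_train_train, ntk_test_train, y_train,
                      f_train_0, f_test_0, t, learning_rate=1.0):
    """Closed-form outputs after time t of continuous gradient descent on
    0.5 * ||f_train - y_train||^2 for a linearized model with a fixed NTK.
    Assumes ntk_train_train is symmetric positive definite."""
    evals, evecs = np.linalg.eigh(ntk_train_train)
    decay = np.exp(-learning_rate * evals * t)  # exp(-lr * lambda * t)
    residual_0 = f_train_0 - y_train
    # Train outputs: y + exp(-lr Theta t) (f_0 - y).
    f_train_t = y_train + (evecs * decay) @ evecs.T @ residual_0
    # Test outputs: f_test_0 - Theta_test,train Theta^-1 (I - exp(-lr Theta t)) (f_0 - y).
    gain = -np.expm1(-learning_rate * evals * t) / evals  # (1 - decay) / lambda
    f_test_t = f_test_0 - ntk_test_train @ ((evecs * gain) @ evecs.T) @ residual_0
    return f_train_t, f_test_t
```

At t = 0 this returns the initial outputs unchanged, and as t grows the train outputs converge to y_train while the test outputs approach f_test_0 - Theta_test,train Theta^-1 (f_train_0 - y_train).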
Continuous gradient descent in an infinite network has been shown in [11] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
For this, we provide two convenient functions:
nt.linearize and nt.taylor_expand, which allow one to linearize or compute an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0, e.g. apply_fn_lin = nt.linearize(apply_fn, params_0).
You can use apply_fn_lin(params, x) exactly as you would any other function
(including as an input to JAX optimizers). This makes it easy to compare the
training trajectory of neural networks with that of their linearization.
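The linearized model is f_lin(params, x) = f(params_0, x) + J(params_0) (params - params_0), i.e. a single Jacobian-vector product around params_0. A JAX-based library would typically obtain that product with automatic differentiation; the dependency-free sketch below swaps in a central finite difference instead. The function names, the flat 1-D params vector, and the toy tanh model are assumptions for illustration, not the nt.linearize implementation.

```python
import numpy as np


def linearize(apply_fn, params_0, eps=1e-5):
    """Return f_lin(params, x) = f(params_0, x) + J(params_0) (params - params_0).
    The Jacobian-vector product is approximated by central differences;
    `params` is assumed to be a flat 1-D NumPy array."""
    def apply_fn_lin(params, x):
        dparams = params - params_0
        jvp = (apply_fn(params_0 + eps * dparams, x)
               - apply_fn(params_0 - eps * dparams, x)) / (2 * eps)
        return apply_fn(params_0, x) + jvp
    return apply_fn_lin


def apply_fn(params, x):
    """Toy nonlinear model tanh(x @ w + b) with params = [w_0, w_1, b]."""
    w, b = params[:2], params[2]
    return np.tanh(x @ w + b)


params_0 = np.array([0.5, -0.3, 0.1])
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
apply_fn_lin = linearize(apply_fn, params_0)
params = params_0 + np.array([0.2, 0.1, -0.1])
# The gap between the model and its linearization grows with |params - params_0|.
gap = np.abs(apply_fn(params, x) - apply_fn_lin(params, x)).max()
```

For a model that is already affine in its parameters the linearization is exact, which is why the affine example below makes a good sanity check.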
Previous theory and experiments have examined the linearization of neural
networks from inputs to logits or pre-activations, rather than from inputs to
post-activations which are substantially more nonlinear.
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
Outputs of a linearized model evolve identically to those of an infinite one [11] but with a different kernel: specifically, the Neural Tangent Kernel [10] evaluated on the specific apply_fn of the finite network given the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function, which accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that computes the empirical NTK and/or NNGP (based on get) kernels on specific params.
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)
key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)
t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
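The empirical NTK is the inner product of parameter Jacobians, Theta[(a, i), (b, j)] = sum over p of df_i(x1[a])/dp * df_j(x2[b])/dp. The NumPy sketch below computes it with a central-difference Jacobian for the same affine model, flattening (W, b) into one vector; the helper names and the flattening are ours, not the nt.empirical_kernel_fn API, and the sketch keeps the full output-by-output structure.

```python
import numpy as np


def jacobian(apply_fn, params, x, eps=1e-6):
    """Central-difference Jacobian of apply_fn(params, x).ravel() with respect
    to a flat 1-D params vector. Shape: (num_outputs, num_params)."""
    columns = []
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        diff = apply_fn(params + step, x) - apply_fn(params - step, x)
        columns.append(diff.ravel() / (2 * eps))
    return np.stack(columns, axis=1)


def empirical_ntk(apply_fn, params, x1, x2):
    """Theta[a, i, b, j] = <d f_i(x1[a]) / d params, d f_j(x2[b]) / d params>."""
    j1 = jacobian(apply_fn, params, x1)
    j2 = jacobian(apply_fn, params, x2)
    n1, n2 = x1.shape[0], x2.shape[0]
    k = j1.shape[0] // n1  # outputs per example
    return (j1 @ j2.T).reshape(n1, k, n2, k)


def apply_fn(params, x):
    """The affine model x @ W + b with W (2x2) and b (2,) packed into one vector."""
    W, b = params[:4].reshape(2, 2), params[4:]
    return x @ W + b


params = np.concatenate([np.eye(2).ravel(), np.zeros(2)])
rng = np.random.default_rng(1)
x_train, x_test = rng.normal(size=(3, 2)), rng.normal(size=(4, 2))
ntk_test_train = empirical_ntk(apply_fn, params, x_test, x_train)  # (4, 2, 3, 2)
```

For this affine model the result has a simple closed form, (x_test[a] · x_train[b] + 1) on matching outputs and 0 across different outputs, which does not depend on params at all.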
The success or failure of the linear approximation is highly architecture-dependent. However, some rules of thumb that we've observed are:
Convergence as the network size increases.
  For fully-connected networks one generally observes very strong agreement by the time the layer width is 512 (RMSE of about 0.05 at the end of training).
  For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.
Convergence at small learning rates.
With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.
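One ingredient behind convergence with network size is that the empirical NTK at initialization concentrates around its infinite-width value as width grows. The NumPy sketch below measures this for a one-hidden-layer ReLU network whose infinite-width NTK at the chosen input is exactly 1.0; the architecture, scaling, and function name are assumptions, and kernel concentration at initialization is only part of why the linear approximation tightens.

```python
import numpy as np


def ntk_at_init(x, width, rng):
    """Empirical NTK Theta(x, x) at a random initialization of
    f(x) = v . relu(W x / sqrt(d)) / sqrt(width), with W, v ~ N(0, 1)."""
    d = x.shape[0]
    W = rng.normal(size=(width, d))
    v = rng.normal(size=width)
    h = W @ x / np.sqrt(d)
    relu, relu_grad = np.maximum(h, 0.0), (h > 0).astype(float)
    # Squared gradient norm: the v-part gives relu(h)^2, the W-part gives
    # v^2 * relu'(h) * |x|^2 / d (relu' is 0/1, so relu'^2 = relu').
    return np.sum(relu**2 + v**2 * relu_grad * (x @ x) / d) / width


rng = np.random.default_rng(0)
x = np.ones(16)  # |x|^2 / d = 1, so the infinite-width value is 0.5 + 0.5 = 1.0
for width in (16, 256, 4096, 65536):
    deviations = [abs(ntk_at_init(x, width, rng) - 1.0) for _ in range(20)]
    print(width, np.mean(deviations))
```

The average deviation should shrink roughly like 1/sqrt(width), consistent with the rule of thumb of starting from a very large model.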
In the tables below we measure the time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:
from neural_tangents import stax
layers = []
for _ in range(21):
  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]
Top layer is stax.GlobalAvgPool():
_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
Platform | Precision (bits) | Milliseconds / NTK entry | Max batch size (N x N) |
---|---|---|---|
CPU, >56 cores, >700 Gb RAM | 32 | 112.90 | >= 128 |
CPU, >56 cores, >700 Gb RAM | 64 | 258.55 | 95 (fastest - 72) |
TPU v2 | 32/16 | 3.2550 | 16 |
TPU v3 | 32/16 | 2.3022 | 24 |
NVIDIA P100 | 32 | 5.9433 | 26 |
NVIDIA P100 | 64 | 11.349 | 18 |
NVIDIA V100 | 32 | 2.7001 | 26 |
NVIDIA V100 | 64 | 6.2058 | 18 |
Top layer is stax.Flatten():
_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
Platform | Precision (bits) | Milliseconds / NTK entry | Max batch size (N x N) |
---|---|---|---|
CPU, >56 cores, >700 Gb RAM | 32 | 0.12013 | 2048 <= N < 4096 (fastest - 512) |
CPU, >56 cores, >700 Gb RAM | 64 | 0.3414 | 2048 <= N < 4096 (fastest - 256) |
TPU v2 | 32/16 | 0.0015722 | 512 <= N < 1024 |
TPU v3 | 32/16 | 0.0010647 | 512 <= N < 1024 |
NVIDIA P100 | 32 | 0.015171 | 512 <= N < 1024 |
NVIDIA P100 | 64 | 0.019894 | 512 <= N < 1024 |
NVIDIA V100 | 32 | 0.0046510 | 512 <= N < 1024 |
NVIDIA V100 | 64 | 0.010822 | 512 <= N < 1024 |
Tested using version 0.2.1. All GPU results are per single accelerator.
Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!
The Colab notebook Performance Benchmark demonstrates how one would construct and benchmark kernels. To demonstrate flexibility, we took the architecture from [16] as an example. With NVIDIA V100 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for Myrtle-5/7/10 kernels.
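To put the GPU-hour figures above into per-entry terms, a quick back-of-the-envelope conversion helps. The sketch assumes all 60,000 x 60,000 train-train entries were computed; whether symmetry was exploited is not stated, and if only about half the entries were computed, the per-entry time would roughly double. The helper name is hypothetical.

```python
def ms_per_entry(gpu_hours, n=60_000):
    """Average milliseconds per kernel entry, assuming all n * n entries of the
    train-train kernel were computed."""
    return gpu_hours * 3600 * 1000 / (n * n)


for name, hours in [("Myrtle-5", 316), ("Myrtle-7", 330), ("Myrtle-10", 508)]:
    print(f"{name}: {ms_per_entry(hours):.3f} ms per entry")
```

Under the full-matrix assumption this works out to 0.316, 0.330 and 0.508 ms per entry, respectively. These numbers are not directly comparable to the 21-layer CNN tables, since the architectures differ.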
Neural Tangents has been used in the following papers (newest first):
Please let us know if you make use of the code in a publication, and we'll add it to the list!
If you use the code in a publication, please cite our ICLR 2020 paper:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://github.com/google/neural-tangents}
}