FedJAX: Federated learning simulation with JAX


Documentation | Paper

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

What is FedJAX?

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research. With its simple primitives for implementing federated learning algorithms, prepackaged datasets, models and algorithms, and fast simulation speed, FedJAX aims to make developing and evaluating federated algorithms faster and easier for researchers. FedJAX works on accelerators (GPU and TPU) without much additional effort. Additional details and benchmarks can be found in our paper.
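To make the core idea concrete, one round of federated averaging (local training on each client, then a size-weighted average on the server) can be sketched in plain NumPy. This is an illustrative sketch of the algorithm only, not FedJAX's API; the function names and hyperparameters here are made up for the example.

```python
import numpy as np

def client_update(params, x, y, lr=0.01, epochs=10):
    # A few steps of gradient descent on one client's
    # linear-regression data (loss = mean squared error).
    for _ in range(epochs):
        grad = 2.0 * np.mean((x * params - y) * x)  # dMSE/dparams
        params = params - lr * grad
    return params

def federated_averaging_round(params, clients):
    # One round: each client trains locally from the shared params,
    # then the server averages the results weighted by dataset size.
    updates, weights = [], []
    for x, y in clients:
        updates.append(client_update(params, x, y))
        weights.append(len(x))
    return np.average(updates, weights=weights)

# Two toy clients: 'a' fits y = 2x, 'b' fits y = 3x.
clients = [
    (np.array([1.0, 2.0, 3.0]), np.array([2.0, 4.0, 6.0])),
    (np.array([4.0]), np.array([12.0])),
]
params = 0.5
for _ in range(100):
    params = federated_averaging_round(params, clients)
```

After many rounds the shared parameter settles between the two per-client optima (2 and 3), weighted toward the larger client, which is the expected behavior of federated averaging on heterogeneous client data.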


Installation

You will need a moderately recent version of Python. Please check the PyPI page for the up-to-date version requirement.

First, install JAX. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install FedJAX from PyPI:

pip install fedjax

Or, to install the latest development version from GitHub:

pip install --upgrade git+https://github.com/google/fedjax.git

Getting Started

Below is a simple example to verify FedJAX is installed correctly.

import fedjax
import jax
import jax.numpy as jnp
import numpy as np

# Mapping of client_id to that client's dataset: {'client_id': client_dataset}.
fd = fedjax.InMemoryFederatedData({
    'a': {
        'x': np.array([1.0, 2.0, 3.0]),
        'y': np.array([2.0, 4.0, 6.0]),
    },
    'b': {
        'x': np.array([4.0]),
        'y': np.array([12.0]),
    },
})
# Initial model parameters.
params = jnp.array(0.5)
# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
# Loss for clients 'a' and 'b'.
print(f"client a loss = {mse_loss(params, fd.get_client('a').all_examples())}")
print(f"client b loss = {mse_loss(params, fd.get_client('b').all_examples())}")
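As a quick sanity check, the same per-client losses can be recomputed with plain NumPy, with no FedJAX dependency (the `clients` and `mse` names below are just for this illustration). With params = 0.5, client 'a' has residuals [-1.5, -3.0, -4.5] and client 'b' has residual -10.0, giving losses of 10.5 and 100.0 respectively.

```python
import numpy as np

params = 0.5
clients = {
    'a': {'x': np.array([1.0, 2.0, 3.0]), 'y': np.array([2.0, 4.0, 6.0])},
    'b': {'x': np.array([4.0]), 'y': np.array([12.0])},
}

def mse(params, batch):
    # Mean squared error of the linear model y = params * x.
    return np.mean((batch['x'] * params - batch['y']) ** 2)

print(mse(params, clients['a']))  # 10.5
print(mse(params, clients['b']))  # 100.0
```

If the FedJAX example above prints these same two values, the installation is working.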

Tutorial notebooks that provide an introduction to FedJAX, along with working examples, are available in the repository and its documentation.

Citing FedJAX

To cite this repository:

@article{fedjax2021,
  title={{F}ed{JAX}: Federated learning simulation with {JAX}},
  author={Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
  journal={arXiv preprint arXiv:2108.02117},
  year={2021}
}

Useful pointers
