JAX Framework Cheatsheet

JAX is a powerful and flexible framework developed by Google for numerical computing that brings together Autograd-style automatic differentiation and the XLA compiler. It’s particularly popular in the machine learning community for its efficiency, composable function transformations, and compatibility with hardware accelerators like GPUs and TPUs. In this cheatsheet, we’ll cover some essential concepts and code snippets to help you master JAX.

Getting Started

Installation

pip install jax jaxlib

Importing JAX

import jax
import jax.numpy as jnp

JAX Basics

NumPy Compatibility

JAX’s jax.numpy module is designed to be a drop-in replacement for NumPy, allowing you to use familiar syntax while leveraging JAX’s benefits.

# NumPy
import numpy as np

# JAX
import jax.numpy as jnp
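
As a quick illustration (a minimal sketch, assuming only the imports above), familiar NumPy-style operations work directly on jnp arrays. One difference worth knowing: JAX arrays are immutable, so in-place updates go through the .at syntax.

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.ones((2, 3))

c = a + b                  # elementwise addition, NumPy-style
m = jnp.dot(a, b.T)        # matrix product
a2 = a.at[0, 0].set(42.0)  # JAX arrays are immutable; .at returns an updated copy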

Functional Programming

JAX encourages a functional programming style built around pure functions, making extensive use of transformations like jax.vmap, jax.lax.map, and jax.lax.scan.

# Example: Elementwise square using vmap
vmap_square = jax.vmap(jnp.square)
result = vmap_square(jnp.array([1, 2, 3]))
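
jax.lax.scan expresses a loop with carried state as a pure function. A minimal sketch of a running (cumulative) sum:

def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, output emitted at this step)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, cumsum = jax.lax.scan(step, 0.0, xs)
# final  -> 10.0
# cumsum -> [ 1.  3.  6. 10.]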

Automatic Differentiation

JAX provides automatic differentiation through the grad function, allowing you to compute gradients effortlessly.

def square(x):
    return x ** 2

grad_square = jax.grad(square)
result_grad = grad_square(3.0)  # d/dx x**2 evaluated at 3.0 -> 6.0
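
Because grad returns an ordinary Python function, it can be applied again. A minimal sketch of a second derivative:

second_derivative = jax.grad(jax.grad(square))
result_second = second_derivative(3.0)  # second derivative of x**2 is 2.0 everywhere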

Advanced JAX

JIT Compilation

JAX’s jit decorator compiles functions with XLA for improved performance. The compilation cost is paid on the first call and amortized over repeated calls, which makes jit especially beneficial for functions invoked many times, such as training steps inside a loop.

@jax.jit
def compiled_function(x):
    return x + 1

result = compiled_function(3)  # first call traces and compiles; later calls reuse the cached version
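
A minimal timing sketch (the matmul and timed names are just for illustration). JAX dispatches work asynchronously, so block_until_ready() is used to wait for the result before stopping the clock:

import time

@jax.jit
def matmul(a, b):
    return a @ b

a = jnp.ones((512, 512))
b = jnp.ones((512, 512))

def timed(f, *args):
    start = time.perf_counter()
    f(*args).block_until_ready()  # wait for asynchronous execution to finish
    return time.perf_counter() - start

print("first call (compiles):", timed(matmul, a, b))
print("second call (cached): ", timed(matmul, a, b))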

Device Assignment

JAX seamlessly integrates with accelerators like GPUs and TPUs. Use jax.device_put to move arrays between devices.

# Create an array (placed on the default device)
x = jnp.arange(4.0)

# Assign to the first GPU (requires a GPU backend)
x_gpu = jax.device_put(x, jax.devices("gpu")[0])

# Assign to the first TPU (requires a TPU backend)
x_tpu = jax.device_put(x, jax.devices("tpu")[0])
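
A minimal sketch for checking what hardware is actually available first; jax.devices("gpu") or jax.devices("tpu") will raise an error if that backend is not present:

print(jax.devices())          # devices on the default backend, e.g. [CpuDevice(id=0)]
print(jax.default_backend())  # "cpu", "gpu", or "tpu"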

Random Numbers

JAX’s random module (jax.random) provides a functional and deterministic approach to generating random numbers.

key = jax.random.PRNGKey(42)
random_values = jax.random.normal(key, shape=(3, 3))
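
Because these functions are pure, calling jax.random.normal twice with the same key returns the same values. To draw independent samples, split the key first:

key, subkey = jax.random.split(key)
sample_a = jax.random.normal(subkey, shape=(2,))

key, subkey = jax.random.split(key)
sample_b = jax.random.normal(subkey, shape=(2,))  # independent of sample_a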

Transformations

JAX offers powerful transformations like jax.jit, jax.grad, and jax.vmap for compiling, differentiating, and vectorizing functions. Each transformation returns an ordinary function, so they can be freely composed.

# Example: Composing transformations
# Decorators apply bottom-up: grad is applied first, then the gradient function is jit-compiled
@jax.jit
@jax.grad
def complex_function(x):
    return x ** 2 + 1
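
Calling the composed function evaluates the JIT-compiled derivative of x ** 2 + 1:

result = complex_function(3.0)  # d/dx (x**2 + 1) at 3.0 -> 6.0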

This cheatsheet serves as a quick reference for mastering JAX, highlighting its essential features and providing code snippets for common tasks. As you delve deeper into JAX, explore its extensive documentation and community resources to unlock its full potential in your numerical computing and machine learning projects.

FAQ

1. What is JAX, and how does it differ from other numerical computing frameworks?

JAX is a numerical computing framework developed by Google that brings together Autograd-style automatic differentiation and the XLA compiler. It distinguishes itself through its seamless integration of automatic differentiation, a functional programming style, and compatibility with hardware accelerators like GPUs and TPUs.

2. How can I accelerate my JAX code for better performance?

You can use JAX’s jit decorator to compile functions for improved performance. Additionally, consider leveraging hardware accelerators such as GPUs and TPUs by using jax.device_put to move arrays between devices.

3. What role does functional programming play in JAX?

JAX encourages a functional programming style, emphasizing pure functions and immutability. Transformations like jax.vmap, jax.lax.map, and jax.lax.scan are integral to expressing computations in a way that facilitates parallelization and optimization.

4. How does JAX handle random number generation?

JAX’s random module (jax.random) provides a functional and deterministic approach to generating random numbers. You can create a key using jax.random.PRNGKey and then use it with functions like jax.random.normal to generate random values.

5. Can I use JAX for automatic differentiation, and how is it implemented?

Yes, JAX provides automatic differentiation through the grad function. Apply jax.grad to a function that returns a scalar, and it will compute the gradient automatically. This feature is crucial for tasks like training machine learning models, where gradient information is essential for optimization.