JAX is a Python library for fast numerical computing that runs on CPUs as well as modern accelerators such as GPUs and TPUs, using the XLA (Accelerated Linear Algebra) compiler under the hood. A common way to understand JAX is to break its name into three core components:
J (just-in-time compilation)
A (autograd)
X (accelerated linear algebra)
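These three pieces are designed to compose. The minimal sketch below uses a toy function `f` of our own choosing (not from any library). `jax.grad` builds the derivative, and `jax.jit` traces that derivative and hands it to XLA for compilation.

```python
import jax

# Toy function for illustration: f(x) = x^3, so f'(x) = 3x^2
def f(x):
    return x ** 3

# A: jax.grad builds the derivative function
# J: jax.jit traces it and compiles it with XLA (the X)
fast_df = jax.jit(jax.grad(f))

print(fast_df(2.0))  # 12.0, since 3 * 2^2 = 12
```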
J: Just-in-Time Compilation
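JAX’s just-in-time (JIT) compilation, exposed through `jax.jit`, traces a Python function into an intermediate representation called a jaxpr (JAX expression). The jaxpr is then compiled by XLA into optimized code for the target device (CPU, GPU, or TPU). This happens the first time the function is called with a given input shape and dtype. The compiled version is cached and reused on later calls, which can significantly boost performance.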
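To see what a jaxpr looks like, `jax.make_jaxpr` traces a function and returns its jaxpr instead of a numerical result. This is a small sketch; `square_plus_one` is a hypothetical example function, and the exact printed text varies between JAX versions. It should list primitive operations such as an integer power and an `add`.

```python
import jax
import jax.numpy as jnp

def square_plus_one(x):
    return x**2 + 1.0

# make_jaxpr traces the function with abstract values (shape and dtype only)
# and returns the resulting jaxpr rather than computing an output array.
jaxpr = jax.make_jaxpr(square_plus_one)(jnp.ones(3))
print(jaxpr)
```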
Key Features:
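Jaxpr Representation: Functions are traced into a sequence of primitive operations (such as add or mul) that JAX and XLA can analyze and optimize.
Lazy Compilation: Compilation happens on the first call for each new combination of input shapes and dtypes. Later calls with matching inputs reuse the cached executable.
Functional Paradigm: JIT compilation assumes pure functions (outputs depend only on inputs, with no side effects), in line with JAX’s design philosophy.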
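The lazy, cached behaviour can be observed by counting how often JAX runs the Python body of a jitted function. The sketch below deliberately uses a side effect (the hypothetical `trace_count` counter) purely for demonstration. Python code inside a jitted function executes only while JAX is tracing, so the counter increases once per new input shape rather than once per call.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how many times JAX traces the Python function

@jax.jit
def scaled(x):
    global trace_count
    trace_count += 1  # side effect runs only during tracing, not on cached calls
    return 2.0 * x

scaled(jnp.ones(3))  # first call: trace and compile for shape (3,)
scaled(jnp.ones(3))  # same shape and dtype: reuses the cached executable
scaled(jnp.ones(5))  # new shape (5,): traces and compiles again
print(trace_count)   # 2
```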
```python
import jax
import jax.numpy as jnp

# Define a function to compute the square of an array
def square_array(x):
    return x**2

# Apply JIT compilation
jit_square = jax.jit(square_array)

# Create an input array
x = jnp.array([1, 2, 3, 4])

# Run the JIT-compiled function (compiled on this first call, cached afterwards)
result = jit_square(x)
print(result)  # Output: [ 1  4  9 16]
```
A: Autograd for Automatic Differentiation
JAX’s automatic differentiation, which grew out of the earlier Autograd library, is a cornerstone of its use in machine learning. It computes derivatives of Python functions written with `jax.numpy`, which is critical for gradient-based optimization and for backpropagation in neural networks.
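Derivatives are not limited to scalar inputs. For a scalar-valued function of an array, `jax.grad` returns an array of partial derivatives with the same shape as the input. In this sketch, `sum_of_squares` is a hypothetical example whose gradient is known to be `2 * w`.

```python
import jax
import jax.numpy as jnp

# L(w) = sum(w_i^2), so the gradient is dL/dw_i = 2 * w_i
def sum_of_squares(w):
    return jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
grad_w = jax.grad(sum_of_squares)(w)
print(grad_w)  # expected: 2 * w, i.e. 2, -4, 6
```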
Key Features:
Gradient Computation: The `jax.grad` function takes a scalar-valued function and returns a new function that computes its gradient (by default, with respect to the first argument).
Higher-Order Derivatives: Since jax.grad returns a function, it can be applied repeatedly to compute higher-order derivatives.
Machine Learning Applications: Gradients drive the updates of model parameters, which makes JAX well suited to building neural networks (often with higher-level libraries such as Flax).
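As a concrete illustration of gradients updating parameters, the sketch below fits a line by plain gradient descent. Libraries such as Flax typically wrap this pattern. Here the loop is written by hand, and the toy data, the `loss` and `update` names, the learning rate of 0.05, and the step count are all assumptions chosen for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated from y = 3x + 1
xs = jnp.array([0.0, 1.0, 2.0, 3.0])
ys = 3.0 * xs + 1.0

def loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)  # mean squared error

@jax.jit
def update(params, x, y):
    # Gradient w.r.t. params only (jax.grad differentiates the first argument by default)
    grads = jax.grad(loss)(params, x, y)
    # One gradient-descent step with a fixed learning rate of 0.05
    return tuple(p - 0.05 * g for p, g in zip(params, grads))

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(1000):
    params = update(params, xs, ys)

print(float(params[0]), float(params[1]))  # close to 3.0 and 1.0
```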
```python
import jax
import jax.numpy as jnp

# Define a simple function: f(x) = x^2 + 2x + 1
def f(x):
    return x**2 + 2*x + 1

# Compute the derivative of f(x)
df = jax.grad(f)

# Evaluate the derivative at x = 3 (jax.grad needs a floating-point input, hence 3.0)
x = 3.0
print(df(x))  # Output: 8.0 (derivative: 2x + 2, so 2*3 + 2 = 8)

# Compute the second derivative
d2f = jax.grad(df)
print(d2f(x))  # Output: 2.0 (second derivative: 2)
```
X: Accelerated Linear Algebra
At its core, JAX provides a NumPy-like interface (`jax.numpy`) for creating and manipulating multi-dimensional arrays, supporting operations like element-wise addition, multiplication, and dot products. Under the hood, XLA compiles these operations into efficient code for CPUs, GPUs, and TPUs. To make that compilation possible, JAX introduces a few key constraints that set it apart from NumPy.
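The operations named above look almost identical to their NumPy counterparts. This short sketch uses arbitrary example values. The output of `jax.devices()` depends on the machine and the installed JAX build, so no expected output is given for it.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

print(a + b)          # element-wise addition: [5. 7. 9.]
print(a * b)          # element-wise multiplication: [ 4. 10. 18.]
print(jnp.dot(a, b))  # dot product: 1*4 + 2*5 + 3*6 = 32.0

# Lists the devices JAX can use (CPU, GPU, or TPU, depending on the install)
print(jax.devices())
```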
Key Features:
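Immutable Arrays: JAX arrays cannot be modified in place. Instead, you index with the `.at` property and call an update method such as `.set`, which returns a new array and leaves the original unchanged.
Pure Functions: JAX’s transformations (such as `jit` and `grad`) assume pure functions, meaning outputs depend only on inputs and there are no side effects. JAX does not check this, so impure code can silently produce stale or surprising results, while pure code stays predictable and optimizable.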
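One common way purity matters in practice is a jitted function that reads a global variable. The global is read once, at trace time, and baked into the compiled code as a constant. The sketch below shows this with the hypothetical names `scale`, `impure_scale`, and `pure_scale`. Passing every dependency as an argument avoids the problem.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global state read by an impure function

@jax.jit
def impure_scale(x):
    return x * scale  # 'scale' is captured when the function is traced

x = jnp.array([1.0, 2.0])
print(impure_scale(x))  # [2. 4.]

scale = 10.0
print(impure_scale(x))  # still [2. 4.]: the cached compiled code uses the old value

# Pure version: everything the function depends on is passed in explicitly
@jax.jit
def pure_scale(x, s):
    return x * s

print(pure_scale(x, 10.0))  # [10. 20.]
```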
```python
import jax.numpy as jnp

# Create a JAX array
arr = jnp.array([1, 2, 3, 4])

# Attempting to mutate directly raises an error
# arr[0] = 10  # TypeError: JAX arrays are immutable

# Correct way: use .at and .set to create a new array
new_arr = arr.at[0].set(10)
print(new_arr)  # Output: [10  2  3  4]
```
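Besides `.set`, the `.at` property supports other update methods such as `.add` and `.multiply`, and it accepts slices as well as single indices. The values in this sketch are arbitrary examples. Each call returns a new array and leaves `arr` untouched. Inside a `jax.jit`-compiled function, XLA can often turn such updates into true in-place writes when the original array is not used again.

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])

added = arr.at[1].add(5)           # add 5 at index 1      -> [1 7 3 4]
zeroed = arr.at[2:].set(0)         # set a slice to 0      -> [1 2 0 0]
doubled = arr.at[::2].multiply(2)  # double every 2nd item -> [2 2 6 4]

print(arr)  # original is unchanged: [1 2 3 4]
```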

