Definition
vmap
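`jax.vmap` is a function transformation that automates batching by vectorising operations. You write a function for a single data sample, and `vmap` returns a new function that applies it across a batch axis of its inputs. The batch axis is axis 0 by default and can be set per argument with `in_axes`. JAX does not run a Python loop over the samples. Instead, it adds the batch dimension to each underlying operation, so the whole batch is computed in one vectorised call. This keeps the per-sample code readable and is usually much faster than looping over samples.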
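The following minimal sketch makes the "write for one sample, apply to a batch" idea concrete by comparing an explicit Python loop with `jax.vmap`. The `normalize` function and the toy data are illustrative assumptions and are not part of the original example. Both routes should give the same result. The vmapped version traces `normalize` once with a batch dimension instead of calling it once per sample.

```python
import jax
import jax.numpy as jnp

def normalize(x: jax.Array) -> jax.Array:
    # Written for ONE sample: a 1-D feature vector.
    # mean/std reduce over that sample's features only.
    return (x - x.mean()) / (x.std() + 1e-6)

xs = jnp.arange(12.0).reshape(4, 3)  # 4 samples, 3 features each

# Manual batching: a Python loop over samples, then stack the results.
looped = jnp.stack([normalize(x) for x in xs])

# vmap: the same per-sample function, mapped over axis 0 of `xs`.
vmapped = jax.vmap(normalize)(xs)

print(vmapped.shape)                          # (4, 3)
print(bool(jnp.allclose(looped, vmapped)))    # True
```

Inside `normalize`, the reductions never see the batch axis. `vmap` keeps them per-sample, which is why the function body needs no changes to work on a batch.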
Example
import jax
import jax.numpy as jnp
def predict(params: tuple[jax.Array, jax.Array], x: jax.Array) -> jax.Array:
    w, b = params
    return jnp.dot(x, w) + b  # x is a single input vector of shape (features,)
# Vectorize over the input `x` (batch axis=0), leave `params` unbatched
batched_predict = jax.vmap(predict, in_axes=(None, 0))
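# Inputs (illustrative values; shapes chosen so each prediction is a scalar):
w = jnp.ones((64,))            # weight vector, one weight per feature
b = jnp.array(0.5)             # scalar bias
params = (w, b)                # shared across the whole batch
x_batch = jnp.ones((128, 64))  # batch of 128 examples, each with 64 features
# Compute predictions for all 128 examples in one vectorised call:
output = batched_predict(params, x_batch)  # shape: (128,)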
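The `in_axes` argument controls which arguments carry a batch axis. The minimal sketch below assumes the same `predict` function, with `w` of shape `(64,)` and a scalar `b`. The random values are chosen only for illustration. The first call reproduces the shared-parameter case above and checks it against an explicit matrix-vector product. The second call batches the parameters as well, so each example gets its own `w` and `b`. Here those per-example copies are identical, so both calls should agree.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Per-example linear model, as above: x has shape (features,).
    w, b = params
    return jnp.dot(x, w) + b

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (64,))
b = jnp.array(0.5)
x_batch = jax.random.normal(key_x, (128, 64))

# Shared parameters (None), batched inputs (0): the case shown above.
shared = jax.vmap(predict, in_axes=(None, 0))((w, b), x_batch)
# For this linear model, that is just a matrix-vector product.
print(bool(jnp.allclose(shared, x_batch @ w + b, atol=1e-4)))  # True

# Per-example parameters: in_axes=(0, 0) maps axis 0 of every leaf
# in the `params` tuple (w_batch and b_batch) as well as of x_batch.
w_batch = jnp.tile(w, (128, 1))   # (128, 64), one weight vector per example
b_batch = jnp.full((128,), 0.5)   # (128,), one bias per example
per_example = jax.vmap(predict, in_axes=(0, 0))((w_batch, b_batch), x_batch)

print(shared.shape, per_example.shape)  # (128,) (128,)
```

Passing `None` treats an argument as shared by every example. Passing `0` requires each leaf of that argument to have the same size along its mapped axis as the other batched inputs. Otherwise, `vmap` raises an error.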