W&B does not have a JAX-specific integration, but works well with JAX training loops through direct calls to wandb.log(). The main consideration is that JAX device arrays must be converted to Python scalars before logging.
Basic setup
import jax
import jax.numpy as jnp
import wandb
wandb.init(project="my-jax-project", config={
    "learning_rate": 1e-3,
    "batch_size": 64,
    "epochs": 50,
})
# Access config
lr = wandb.config.learning_rate
Converting JAX arrays before logging
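Values returned from jit-compiled functions are JAX arrays that live on the accelerator, not Python numbers. Convert scalar results with .item() or float() before passing them to wandb.log(). The conversion waits for the computation to finish and copies the value to the host: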
@jax.jit
def train_step(params, batch):
    # loss_fn and update_params are your own model and optimizer functions
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = update_params(params, grads)
    return params, loss

for step, batch in enumerate(dataloader):
    params, loss = train_step(params, batch)
    # loss is a JAX array, so convert it before logging
    wandb.log({"train/loss": float(loss)}, step=step)
Passing a JAX device array directly to wandb.log() without conversion may cause serialization errors or silently log incorrect values.
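When a step produces several metrics, it can be convenient to convert them all in one place. The following is a minimal sketch, not part of the W&B API: to_python_scalars is a hypothetical helper that assumes a flat dict whose values are scalar JAX arrays. It uses jax.device_get to copy the whole dict to the host at once, then converts each value with float().

```python
import jax
import jax.numpy as jnp


def to_python_scalars(metrics):
    """Convert a flat dict of scalar JAX arrays into plain Python floats.

    Hypothetical helper: assumes every value is a scalar (0-d) array.
    """
    # device_get copies the whole dict to host memory in a single call.
    host_metrics = jax.device_get(metrics)
    return {name: float(value) for name, value in host_metrics.items()}


# Example metrics as they might come back from a jit-compiled step.
metrics = {"train/loss": jnp.asarray(0.25), "train/accuracy": jnp.asarray(0.5)}
loggable = to_python_scalars(metrics)
```

The resulting dict contains only Python floats, so it can be passed straight to wandb.log(loggable, step=step).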
Logging at epoch boundaries
For validation metrics that require a full pass over the validation set, aggregate in Python and log once per epoch:
for epoch in range(num_epochs):
    # training loop ...
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))
    wandb.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })
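A plain mean of per-batch losses weights every batch equally, which skews the result slightly when the last validation batch is smaller than the others. One possible refinement is to weight each batch by its size. The sketch below assumes each entry in batch_losses is a per-batch mean loss already converted to a Python float; weighted_mean is a hypothetical helper name.

```python
def weighted_mean(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its example count."""
    if len(batch_losses) != len(batch_sizes):
        raise ValueError("batch_losses and batch_sizes must have the same length")
    total_examples = sum(batch_sizes)
    if total_examples == 0:
        raise ValueError("no validation examples were seen")
    # Multiply each mean back out to a per-batch sum, then divide by the total count.
    total_loss = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total_loss / total_examples


# Two batches: 3 examples with mean loss 1.0, then 1 example with loss 3.0.
val_loss = weighted_mean([1.0, 3.0], [3, 1])
```

In this example the unweighted mean would be 2.0, while the size-weighted value matches the true per-example average.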
Logging model checkpoints as artifacts
Save JAX/Flax model parameters with orbax or flax.serialization and log them as W&B artifacts:
import orbax.checkpoint as ocp
# Save checkpoint to disk
checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
# Saving may run in the background; wait for it to finish before uploading
checkpointer.wait_until_finished()
# Log as a W&B artifact
artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
wandb.log_artifact(artifact)
Watching for NaNs
JAX does not raise errors on NaN values by default; they propagate silently through later computations. Log a NaN check alongside your loss so you can spot training instability early:
wandb.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
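A NaN loss is often preceded by exploding or non-finite gradients, so it may help to log a gradient summary as well. The sketch below assumes your training step also returns the gradient pytree (the train_step shown earlier returns only params and loss). gradient_health and the metric names are hypothetical, not a W&B or JAX API.

```python
import jax
import jax.numpy as jnp


def gradient_health(grads):
    """Summarize a gradient pytree: global L2 norm and whether all entries are finite."""
    leaves = jax.tree_util.tree_leaves(grads)
    # Global norm: square root of the sum of squares over every leaf.
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # True only if no leaf contains NaN or infinity.
    all_finite = all(bool(jnp.all(jnp.isfinite(g))) for g in leaves)
    return {
        "train/grad_norm": float(global_norm),
        "train/grads_finite": all_finite,
    }


# Toy gradient pytree standing in for the output of jax.value_and_grad.
grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
health = gradient_health(grads)
```

The returned dict holds only Python values, so it can be merged into the metrics passed to wandb.log().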
You can also enable JAX’s debug NaN checking globally during development (at a performance cost):
from jax import config
config.update("jax_debug_nans", True)
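With the flag enabled, an operation that produces a NaN raises FloatingPointError instead of returning silently. The following sketch illustrates that behavior on a deliberately invalid division; first_nan_error is a hypothetical helper written for this example, and the flag is switched back off at the end so the rest of the program is unaffected.

```python
import jax
import jax.numpy as jnp

# Turn on NaN checking for everything that runs after this line.
jax.config.update("jax_debug_nans", True)


def first_nan_error(fn, *args):
    """Run fn(*args); return the FloatingPointError message, or None if none was raised."""
    try:
        fn(*args)
    except FloatingPointError as err:
        return str(err)
    return None


error_message = first_nan_error(jnp.divide, 0.0, 0.0)  # 0/0 produces NaN
ok_message = first_nan_error(jnp.divide, 1.0, 2.0)     # ordinary division

# Restore the default so later code runs without the extra checks.
jax.config.update("jax_debug_nans", False)
```

Because the check adds overhead to every operation, it is usually worth enabling only while tracking down a specific instability.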