W&B does not have a JAX-specific integration, but it works well with JAX training loops through direct calls to wandb.log(). The main consideration is that JAX device arrays must be converted to Python scalars before logging.

Basic setup
import jax
import jax.numpy as jnp
import wandb

wandb.init(project="my-jax-project", config={
    "learning_rate": 1e-3,
    "batch_size": 64,
    "epochs": 50,
})

# Access config
lr = wandb.config.learning_rate
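Converting JAX arrays before logging

Values returned from jit-compiled functions are JAX arrays (jax.Array), which may still live on the accelerator. Call .item() or wrap them in float() before passing them to wandb.log(). Either call copies the value to host memory and waits for the computation that produces it to finish: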
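# loss_fn(params, batch) -> scalar loss is assumed to be defined for your model
@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Plain SGD update for illustration; swap in your optimizer (for example optax)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# params (initial parameters) and dataloader are assumed to be defined elsewhere
for step, batch in enumerate(dataloader):
    params, loss = train_step(params, batch)

    # loss is a scalar JAX array, so convert it before logging
    wandb.log({"train/loss": float(loss)}, step=step)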
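When a step function returns several metrics, calling float() on each one triggers a separate host transfer per value. The following is a minimal sketch of one way to fetch them together. It assumes the metrics arrive as a flat dict of scalar JAX arrays. `to_loggable` is a hypothetical helper name, not a W&B or JAX API.

```python
import jax
import jax.numpy as jnp


def to_loggable(metrics):
    """Convert a flat dict of scalar JAX arrays into plain Python floats.

    Assumes every value is a scalar (0-d) array; float() rejects larger arrays.
    """
    # Fetch all metrics to host memory with one jax.device_get call.
    # The result holds NumPy arrays instead of JAX arrays.
    host_metrics = jax.device_get(metrics)
    return {name: float(value) for name, value in host_metrics.items()}


# Example: metrics shaped like a jitted step function might return them
metrics = {"train/loss": jnp.array(0.5), "train/accuracy": jnp.array(0.75)}
loggable = to_loggable(metrics)  # plain Python floats, ready for wandb.log
```

In a training loop you would then call `wandb.log(to_loggable(metrics), step=step)`.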
Passing a JAX device array directly to wandb.log() without conversion may cause serialization errors or silently log incorrect values.

Logging at epoch boundaries

For validation metrics that require a full pass over the validation set, aggregate in Python and log once per epoch:
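for epoch in range(num_epochs):
    # training loop ...

    # eval_step(params, batch) -> scalar loss and val_loader are assumed to be defined elsewhere
    val_losses = []
    for val_batch in val_loader:
        val_loss = eval_step(params, val_batch)
        val_losses.append(float(val_loss))  # convert each batch loss to a Python float

    # Unweighted mean over batches (assumes equal-size validation batches)
    wandb.log({
        "epoch": epoch,
        "val/loss": sum(val_losses) / len(val_losses),
    })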
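Suppose you log training loss with `step=step` from `enumerate(dataloader)` inside an epoch loop. The step then restarts at 0 every epoch. W&B expects logged steps to increase, so it warns about and ignores logs sent to an earlier step.

The following is a minimal sketch of one way around this, under a few assumptions:

* It keeps a single global step counter across epochs.
* It averages training losses on the device and converts them only every `log_every` steps.
* It reduces the validation losses with `jnp.mean` before a single float() call.

The function `run_epochs`, its parameters, and the `log_fn(metrics, step)` callback are hypothetical. Passing the logger in as a callback keeps the loop testable without a W&B run.

```python
import jax.numpy as jnp


def run_epochs(train_step, eval_step, params, train_batches, val_batches,
               num_epochs, log_fn, log_every=100):
    """Train for num_epochs, logging against one global step counter.

    Assumes train_step(params, batch) -> (params, loss) and
    eval_step(params, batch) -> loss return scalar JAX arrays, and that
    train_batches and val_batches are re-iterable (for example lists).
    log_fn(metrics, step) is called with plain Python values.
    """
    global_step = 0  # never reset, so logged steps keep increasing across epochs
    pending = []     # training losses kept on device until the next log
    for epoch in range(num_epochs):
        for batch in train_batches:
            params, loss = train_step(params, batch)
            pending.append(loss)
            global_step += 1
            if global_step % log_every == 0:
                # One host transfer per log_every steps instead of one per step
                mean_loss = float(jnp.mean(jnp.stack(pending)))
                log_fn({"train/loss": mean_loss}, global_step)
                pending = []

        # Validation: stack per-batch losses on device, then convert once
        val_losses = jnp.stack([eval_step(params, b) for b in val_batches])
        log_fn({"epoch": epoch, "val/loss": float(jnp.mean(val_losses))},
               global_step)
    return params
```

With W&B you might pass `log_fn=lambda metrics, step: wandb.log(metrics, step=step)`. To chart validation metrics against epochs instead of steps, W&B also provides `wandb.define_metric`, for example `wandb.define_metric("val/loss", step_metric="epoch")`.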
Logging model checkpoints as artifacts

Save JAX/Flax model parameters with Orbax (orbax.checkpoint) or flax.serialization, and log them as W&B artifacts:
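import orbax.checkpoint as ocp

# Save checkpoint to disk. Orbax will not overwrite an existing directory,
# so use a fresh path for each checkpoint.
checkpointer = ocp.StandardCheckpointer()
checkpointer.save("/tmp/checkpoint", params)
# Recent Orbax versions save asynchronously; wait for the write to
# complete before adding the directory to an artifact.
checkpointer.wait_until_finished()

# Log the checkpoint directory as a W&B artifact
artifact = wandb.Artifact("jax-model", type="model")
artifact.add_dir("/tmp/checkpoint")
wandb.log_artifact(artifact)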
Watching for NaNs

JAX does not raise errors on NaN values by default; they propagate silently. Log a NaN check alongside your loss so you can spot training instability early:
wandb.log({
    "train/loss": float(loss),
    "train/loss_is_nan": bool(jnp.isnan(loss)),
})
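A NaN flag on the loss alone misses values that overflow to infinity before they become NaN. The sketch below shows one possible extension. It checks every metric with `np.isfinite`, which is False for NaN and for positive or negative infinity. It assumes a flat dict of scalar JAX arrays, and `nonfinite_metrics` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def nonfinite_metrics(metrics):
    """Return the sorted names of metrics whose value is NaN or +/-inf.

    Assumes `metrics` is a flat dict of scalar JAX arrays.
    """
    host = jax.device_get(metrics)  # one host transfer for all metrics
    return sorted(name for name, value in host.items() if not np.isfinite(value))


metrics = {"train/loss": jnp.array(jnp.inf), "train/grad_norm": jnp.array(2.0)}
print(nonfinite_metrics(metrics))  # ['train/loss']
```

You could log `len(bad) > 0` as a flag next to the loss, or stop training as soon as the returned list is non-empty.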
You can also enable JAX’s debug NaN checking globally during development. With it on, JAX raises a FloatingPointError as soon as a computation produces a NaN, at a performance cost:
from jax import config
config.update("jax_debug_nans", True)
