Added
- Added an option for device-parallel evaluation of `BBOBFitness`.
- Added fully `pmap`-compatible implementations of `OpenES`, `PGPE`, `Sep_CMA_ES` and `SNES`. Example: [`09_pmap_strategy.ipynb`](https://github.com/RobertTLange/evosax/blob/main/examples/09_pmap_strategy.ipynb):
  ```python
  # Set the number of CPU devices for jax.pmap (must be done before importing jax)
  import os

  num_devices = 4
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_devices}"

  import jax
  import jax.numpy as jnp
  from flax import jax_utils

  print(jax.devices())

  from evosax.problems import BBOBFitness
  from evosax.v2 import SNES

  fn_name = "Sphere"
  num_dims = 2
  popsize = 64
  rng = jax.random.PRNGKey(0)

  # Fitness evaluation is split across devices
  problem = BBOBFitness(fn_name, num_dims=num_dims, n_devices=num_devices)
  strategy = SNES(
      popsize=popsize,
      num_dims=num_dims,
      sigma_init=0.1,
      n_devices=num_devices,
      maximize=False,
  )

  # Replicate the strategy params and the init key (one copy per device)
  params = strategy.default_params.replace(init_min=-3.0, init_max=3.0)
  params = jax_utils.replicate(params)
  init_rng = jnp.tile(rng[None], (num_devices, 1))
  state = jax.pmap(strategy.initialize)(init_rng, params)
  print("Mean pre-update:", state.mean)  # (num_devices, num_dims)

  # Ask: each device samples its own share of the population
  rng, rng_a, rng_e = jax.random.split(rng, 3)
  ask_rng = jax.random.split(rng_a, num_devices)
  x, state = jax.pmap(strategy.ask, axis_name="device")(ask_rng, state, params)
  print("Population shape:", x.shape)  # (num_devices, popsize / num_devices, num_dims)

  fitness = problem.rollout(rng_e, x)
  print("Fitness shape:", fitness.shape)  # (num_devices, popsize / num_devices)

  # Tell: axis_name names the mapped axis used for cross-device collectives
  state = jax.pmap(strategy.tell, axis_name="device")(x, fitness, state, params)
  print("Mean post-update:", state.mean)  # (num_devices, num_dims)
  ```
- Added `DiffusionEvolution` based on [Zhang et al. (2024)](https://arxiv.org/pdf/2410.02543). Example: [`10_diffusion_evolution.ipynb`](https://github.com/RobertTLange/evosax/blob/main/examples/10_diffusion_evolution.ipynb)
- Added `SV_CMA_ES` ([Braun et al., 2024](https://arxiv.org/abs/2410.10390)) and `SV_OpenES` ([Liu et al., 2017](https://arxiv.org/abs/1704.02399)) as subpopulation-based ES with Stein variational updates.
Big thanks to Cornelius Braun (@cornelius-braun) for adding the Stein variational methods!
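This sketch shows, in plain JAX, what "`pmap`-compatible" can mean for an ES. It is a hypothetical illustration, not evosax's implementation, and `sharded_step`, `POP_PER_DEVICE` and the OpenES-style update are all assumptions. Each device samples and evaluates its own slice of the population. Fitness statistics and the search-gradient estimate are then averaged over the `"device"` axis with `jax.lax.pmean`, so every replica applies the same update and the per-device means stay identical. It runs with however many devices JAX sees, including a single CPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch (not evosax internals): population sharded across devices,
# OpenES-style update aggregated with cross-device collectives.
NUM_DEVICES = jax.local_device_count()
NUM_DIMS = 2
POP_PER_DEVICE = 16  # per-device share of the population (must be even)
SIGMA = 0.1          # fixed sampling noise scale
LR = 0.01            # learning rate for the search mean


def sphere(x):
    """Sphere benchmark: sum of squares, minimised at the origin."""
    return jnp.sum(x**2, axis=-1)


def sharded_step(rng, mean):
    # Antithetic Gaussian perturbations for this device's slice of the population
    half = jax.random.normal(rng, (POP_PER_DEVICE // 2, NUM_DIMS))
    eps = jnp.concatenate([half, -half], axis=0)
    fitness = sphere(mean + SIGMA * eps)

    # Z-score fitness using mean/variance pooled across all devices
    f_mean = jax.lax.pmean(jnp.mean(fitness), axis_name="device")
    f_var = jax.lax.pmean(jnp.mean((fitness - f_mean) ** 2), axis_name="device")
    f_norm = (fitness - f_mean) / (jnp.sqrt(f_var) + 1e-8)

    # Local search-gradient estimate, averaged over devices so all replicas agree
    grad = eps.T @ f_norm / (POP_PER_DEVICE * SIGMA)
    grad = jax.lax.pmean(grad, axis_name="device")
    return mean - LR * grad  # descend, since we minimise


step = jax.pmap(sharded_step, axis_name="device")

key = jax.random.PRNGKey(0)
# One replica of the search mean per device: shape (NUM_DEVICES, NUM_DIMS)
mean = jnp.tile(jnp.array([1.0, -1.0]), (NUM_DEVICES, 1))
for _ in range(300):
    key, sub = jax.random.split(key)
    mean = step(jax.random.split(sub, NUM_DEVICES), mean)

print("Mean per device:", mean)
```

The design choice is that only the small aggregates (two scalars and one `NUM_DIMS` vector) cross devices. Candidate solutions and their fitness values never leave the device that produced them.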
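The Stein variational update behind `SV_CMA_ES` and `SV_OpenES` moves a set of particles (here, one per subpopulation) along the SVGD direction \(\phi(x_i) = \frac{1}{n}\sum_j \left[k(x_j, x_i)\,\nabla \log p(x_j) + \nabla_{x_j} k(x_j, x_i)\right]\). The first term pulls particles toward high-density regions; the kernel-gradient term pushes them apart. Below is a minimal sketch of that generic update with a fixed-bandwidth RBF kernel. It assumes a target whose score \(\nabla \log p\) is known (a Gaussian). In the ES variants, the per-particle gradient would instead be a search-gradient estimate for each subpopulation. `rbf_kernel` and `svgd_direction` are hypothetical helpers, not the evosax code.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth):
    # k(x, y) = exp(-||x - y||^2 / h)
    return jnp.exp(-jnp.sum((x - y) ** 2) / bandwidth)


@jax.jit
def svgd_direction(particles, grads, bandwidth):
    """phi(x_i) = mean_j [k(x_j, x_i) * grads[j] + grad_{x_j} k(x_j, x_i)]."""
    # K[j, i] = k(x_j, x_i)
    K = jax.vmap(
        lambda xj: jax.vmap(lambda xi: rbf_kernel(xj, xi, bandwidth))(particles)
    )(particles)
    # dK[j, i] = gradient of k(x_j, x_i) w.r.t. its first argument x_j
    dK = jax.vmap(
        lambda xj: jax.vmap(lambda xi: jax.grad(rbf_kernel)(xj, xi, bandwidth))(particles)
    )(particles)
    driving = K.T @ grads        # sum_j K[j, i] * grads[j]  -> (n, d)
    repulsive = dK.sum(axis=0)   # sum_j dK[j, i]            -> (n, d)
    return (driving + repulsive) / particles.shape[0]


# Toy target N(target, I): its score is -(x - target)
target = jnp.array([2.0, -1.0])
particles = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
for _ in range(500):
    grads = -(particles - target)
    particles = particles + 0.1 * svgd_direction(particles, grads, 1.0)

print("Particle mean:", particles.mean(axis=0))
```

Without the repulsive term, all particles would collapse onto the mode. With it, they settle around the target while staying spread out, which is what lets the subpopulations explore different regions.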