New functionalities
- Extended all GP and DKL models to vector-valued targets
- Added an ensemble mode to viDKL for more accurate predictions and better uncertainty estimates
- Added a 'batch update' mode to Thompson and UCB acquisition functions with ExactGP, vExactGP, and DKL models.
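Vector-valued targets in the examples below carry a leading dimension (e.g., `y_train` with shape `(2, 32)`). The NumPy sketch below shows what that shape convention means. It assumes each output is modeled by an independent GP with fixed RBF hyperparameters; gpax itself infers the hyperparameters, and its internals may differ. Under that assumption, the posterior means for all outputs can be computed in one batched call. The helper names `rbf_kernel` and `batched_gp_mean` are hypothetical.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """Batched RBF kernel: x1 (..., n, d), x2 (..., m, d) -> (..., n, m)."""
    sq_dist = ((x1[..., :, None, :] - x2[..., None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale ** 2)

def batched_gp_mean(X_train, y_train, X_new, noise=1e-2):
    """Posterior mean of independent GPs, one per leading index (hypothetical helper).

    X_train: (n_outputs, n, d), y_train: (n_outputs, n), X_new: (n_outputs, m, d)
    Returns an array of shape (n_outputs, m).
    """
    # Noisy training covariance for every output at once: (n_outputs, n, n)
    K = rbf_kernel(X_train, X_train) + noise * np.eye(X_train.shape[-2])
    # Cross-covariance between new and training points: (n_outputs, m, n)
    K_star = rbf_kernel(X_new, X_train)
    # Batched solve of K @ alpha = y for each output: (n_outputs, n, 1)
    alpha = np.linalg.solve(K, y_train[..., None])
    return (K_star @ alpha)[..., 0]

rng = np.random.default_rng(0)
X_train = rng.standard_normal((2, 32, 4))   # n_outputs x n_samples x n_features
y_train = rng.standard_normal((2, 32))      # n_outputs x n_samples
X_new = rng.standard_normal((2, 100, 4))
mean = batched_gp_mean(X_train, y_train, X_new)
print(mean.shape)  # (2, 100)
```

Because the outputs are independent here, each slice along the leading axis equals what a single-output GP would return for that output alone.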
Usage examples:
Ensemble deep kernel learning for vector-valued targets and multi-modal inputs with a custom neural network (preferably run on a GPU):
```python3
import numpy as np
import gpax
import haiku as hk
import jax

# Define a custom feature extractor for DKL
class MLP2(hk.Module):
    """Simple custom MLP"""
    def __init__(self, embedim=2):
        super().__init__()
        self._embedim = embedim

    def __call__(self, x):
        x = hk.Linear(128)(x)
        x = jax.nn.tanh(x)
        x = hk.Linear(64)(x)
        x = jax.nn.tanh(x)
        x = hk.Linear(32)(x)
        x = jax.nn.tanh(x)
        x = hk.Linear(self._embedim)(x)
        return x

# Multi-modal high-dimensional inputs
X_train = np.random.randn(2, 32, 144)  # n_modes x n_samples x n_features
X_unmeasured = np.random.randn(2, 100, 144)
# Vector-valued targets
y_train = np.random.randn(2, 32)

# Initialize and run an ensemble of 10 DKL models
key, _ = gpax.utils.get_keys()
dkl = gpax.viDKL(input_dim=X_train.shape[-1], z_dim=2, kernel='RBF', nn=MLP2)
y_means, y_vars = dkl.fit_predict(
    key, X_train, y_train, X_unmeasured,
    n_models=10, num_steps=1000, step_size=0.005)

# Average ensemble predictions
y_mean = y_means.mean(0)
y_var = y_vars.mean(0)
```
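Averaging `y_vars` as above captures each model's own uncertainty but not the disagreement between ensemble members. A common alternative applies the law of total variance, treating the ensemble as an equally weighted mixture. The sketch below assumes `y_means` and `y_vars` are stacked along axis 0, as in the example above; `combine_ensemble` is a hypothetical helper, not part of gpax.

```python
import numpy as np

def combine_ensemble(y_means, y_vars):
    """Combine per-model predictions stacked along axis 0 (hypothetical helper).

    Returns the ensemble mean and the total predictive variance, i.e. the
    mean of the member variances plus the variance of the member means.
    """
    mean = y_means.mean(0)
    total_var = y_vars.mean(0) + y_means.var(0)
    return mean, total_var

# Two models that agree on the second point but disagree on the first
y_means = np.array([[0.0, 1.0], [2.0, 1.0]])
y_vars = np.array([[1.0, 1.0], [1.0, 1.0]])
mean, total_var = combine_ensemble(y_means, y_vars)
print(mean, total_var)  # [1. 1.] [2. 1.]
```

The extra term only inflates the variance where the members disagree, so points on which all models agree keep the plain averaged variance.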
The next example illustrates a 'batch mode' for Thompson sampling with a fully Bayesian DKL. Here, y_train holds the values of a scalar physical property measured in the image patches X_train, and X_unmeasured holds the image patches in which this property has not been measured yet. The indices describe the locations of the image patches in the experimental field of view.
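One way such inputs could be prepared is sketched below. It assumes a single 2D image, a square sliding window, and a random subset of already-measured patches. The helper `extract_patches` and the split are hypothetical illustrations, not gpax utilities, and `y_train` would come from the actual measurements at the measured locations.

```python
import numpy as np

def extract_patches(image, window_size, step=1):
    """Hypothetical helper: slide a square window over a 2D image.

    Returns flattened patches, shape (n_patches, window_size**2), and the
    (row, col) coordinates of each patch center, shape (n_patches, 2).
    """
    h, w = image.shape
    patches, centers = [], []
    for i in range(0, h - window_size + 1, step):
        for j in range(0, w - window_size + 1, step):
            patches.append(image[i:i + window_size, j:j + window_size].ravel())
            centers.append((i + window_size // 2, j + window_size // 2))
    return np.array(patches), np.array(centers)

rng = np.random.default_rng(0)
image = rng.standard_normal((64, 64))  # stand-in for an experimental image
patches, indices_all = extract_patches(image, window_size=12, step=4)

# Pretend 32 patches were already measured; the rest are candidates
mask = np.zeros(len(patches), dtype=bool)
mask[rng.choice(len(patches), size=32, replace=False)] = True
X_train = patches[mask]          # inputs for the measured patches
X_unmeasured = patches[~mask]    # candidate patches
indices = indices_all[~mask]     # field-of-view locations of the candidates
print(patches.shape, X_unmeasured.shape, indices.shape)  # (196, 144) (164, 144) (164, 2)
```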
```python3
# Initialize DKL model
data_dim = X_train.shape[-1]
key1, key2 = gpax.utils.get_keys()
dkl = gpax.DKL(data_dim, z_dim=2, kernel='RBF')
# Obtain posterior samples for model parameters
dkl.fit(key1, X_train, y_train, num_warmup=333, num_samples=333,
        num_chains=3, chain_method='vectorized')

# Batch mode of UCB = alpha*mu + sqrt(beta*var)
# Generate the next 5 points to probe that are maximally apart from each other
# (indices: locations of the X_unmeasured patches in the field of view)
obj = gpax.acquisition.bUCB(
    key2, dkl, X_unmeasured, indices=indices,
    alpha=1, beta=0, n_restarts=10, batch_size=5)
next_points_idx = obj.argmax(-1)
```
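The comments in the example combine two ideas. First, a UCB objective is evaluated separately for each posterior sample of the DKL hyperparameters. Second, restarts favor batches whose points are far apart in the field of view. The NumPy sketch below illustrates both under stated assumptions: precomputed per-sample predictive means and variances, and the mean pairwise Euclidean distance as the spread criterion. It is not gpax's implementation of `bUCB`, whose internals may differ, and `batch_ucb_select` is a hypothetical name. With `beta=0`, each curve is simply the predictive mean under one hyperparameter sample.

```python
import numpy as np

def batch_ucb_select(post_means, post_vars, locations, batch_size,
                     alpha=1.0, beta=0.0, n_restarts=10, seed=0):
    """Simplified sketch of batch UCB selection (hypothetical helper).

    post_means, post_vars: (n_hyper_samples, n_points) predictive means and
        variances, one row per posterior sample of the kernel hyperparameters.
    locations: (n_points, 2) coordinates of the candidate points.
    Returns the selected point indices, shape (batch_size,).
    """
    rng = np.random.default_rng(seed)
    best_idx, best_spread = None, -np.inf
    for _ in range(n_restarts):
        # One UCB curve per randomly chosen hyperparameter sample
        rows = rng.choice(post_means.shape[0], size=batch_size, replace=False)
        acq = alpha * post_means[rows] + np.sqrt(beta * post_vars[rows])
        idx = acq.argmax(-1)  # one candidate point per curve
        pts = locations[idx].astype(float)
        # Mean pairwise Euclidean distance between the candidate points
        dists = np.sqrt(((pts[:, None, :] - pts[None, :, :]) ** 2).sum(-1))
        iu = np.triu_indices(batch_size, k=1)
        spread = dists[iu].mean() if batch_size > 1 else 0.0
        # Keep the restart whose batch is most spread out
        if spread > best_spread:
            best_idx, best_spread = idx, spread
    return best_idx

# Toy case: five points on a line; hyperparameter sample k peaks at point k
locations = np.array([[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]])
post_means = np.eye(5)
post_vars = np.ones((5, 5))
batch = batch_ucb_select(post_means, post_vars, locations,
                         batch_size=2, n_restarts=200)
print(sorted(batch.tolist()))
```

In the toy case, the two most distant points (0 and 4) form the most spread-out batch, so enough restarts should recover them.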