fvGP Single Task Notebook

In this notebook we go through a few features of fvGP, focusing on regression with a single output dimension and a single task. See multiple_task_test_notebook.ipynb for a single-dimension, multi-task example. The extension to multiple dimensions is straightforward.

Import fvgp and relevant libraries

import fvgp
from fvgp import gp
import numpy as np
import matplotlib.pyplot as plt

Defining some input data and testing points

def function(x):
    return np.sin(1.1 * x)+np.cos(0.5 * x)
x_data = np.linspace(-2*np.pi, 10*np.pi,50).reshape(-1,1)
y_data = function(x_data)
x_pred = np.linspace(-2*np.pi, 10 * np.pi, 1000)

Setting up the fvgp single task object

NOTE: The input data must be given as an array of shape (N x input_space_dim). The output can be either an array of shape (N,) or (N x 1), where N is the number of data points. See help(gp.GP) for more information.
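A quick NumPy-only illustration of the shape conventions in the note: inputs as an (N x 1) column, outputs as either (N x 1) or (N,). Using `ravel()` to flatten the output is just one way to switch between the two; nothing here assumes how fvGP normalizes shapes internally.

```python
import numpy as np

# 50 one-dimensional inputs stored as an (N x input_space_dim) array
x_data = np.linspace(-2 * np.pi, 10 * np.pi, 50).reshape(-1, 1)

# Outputs computed from a column input come out as (N x 1) ...
y_col = np.sin(1.1 * x_data) + np.cos(0.5 * x_data)
# ... and can be flattened to an (N,) array
y_flat = y_col.ravel()

# Prediction points created as a flat array must be reshaped to (M x 1)
x_pred = np.linspace(-2 * np.pi, 10 * np.pi, 1000)
x_pred_col = x_pred.reshape(-1, 1)

print(x_data.shape, y_col.shape, y_flat.shape, x_pred_col.shape)
```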

# the first argument (1) is input_space_dim
obj = gp.GP(1, x_data, y_data, init_hyperparameters = np.array([10, 10]), use_inv = False)

Training our Gaussian process regression on the given data

# one [lower, upper] row per hyperparameter
hyper_param_bounds = np.array([[0.001, 5.], [0.001, 100]])
## This blocks the main thread, even with method = "hgdl"; the other options are "global" and "local".
obj.train(hyper_param_bounds, method = "hgdl")
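To make the training step less of a black box, here is a minimal sketch of the kind of objective GP training typically minimizes: the negative log marginal likelihood. Several things are assumptions, not fvGP internals: the squared-exponential kernel (fvGP's default kernel may differ), the zero prior mean, the fixed 1e-6 noise term, and the coarse grid search, which stands in for the "hgdl", "global", and "local" optimizers.

```python
import numpy as np

def function(x):
    return np.sin(1.1 * x) + np.cos(0.5 * x)

def sq_exp_kernel(x1, x2, hp):
    # Illustrative kernel: hp[0] = signal variance, hp[1] = length scale
    d = np.subtract.outer(x1, x2)
    return hp[0] * np.exp(-0.5 * (d / hp[1]) ** 2)

def neg_log_marginal_likelihood(hp, x, y, noise=1e-6):
    """-log p(y | x, hp) for a zero-mean GP with a small diagonal noise term."""
    K = sq_exp_kernel(x, x, hp) + noise * np.eye(len(x))
    try:
        L = np.linalg.cholesky(K)
    except np.linalg.LinAlgError:
        return np.inf  # numerically not positive definite: reject these hyperparameters
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # alpha = K^{-1} y
    return (0.5 * y @ alpha                    # data-fit term
            + np.sum(np.log(np.diag(L)))       # equals 0.5 * log det K
            + 0.5 * len(x) * np.log(2.0 * np.pi))

x = np.linspace(-2 * np.pi, 10 * np.pi, 50)
y = function(x)

# Same layout as hyper_param_bounds: one [lower, upper] row per hyperparameter
bounds = np.array([[0.001, 5.0], [0.001, 100.0]])

# Coarse grid search inside the bounds
grid = [(a, b)
        for a in np.linspace(bounds[0, 0], bounds[0, 1], 20)
        for b in np.linspace(bounds[1, 0], bounds[1, 1], 40)]
values = [neg_log_marginal_likelihood(np.array(hp), x, y) for hp in grid]
best_hp = np.array(grid[int(np.argmin(values))])
print("best grid point:", best_hp, "objective:", min(values))
```

A gradient-based or global optimizer replaces the grid in practice; the objective itself only needs the data, a kernel, and the hyperparameters.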

Looking at the posterior mean and variance at the test points

post_mean= obj.posterior_mean(x_pred.reshape(-1,1))["f(x)"]
post_var= obj.posterior_covariance(x_pred.reshape(-1,1))["v(x)"]
plt.plot(x_pred, post_mean, label='gp interpolation')
plt.scatter(x_data, y_data, label='data')
plt.plot(x_pred,function(x_pred), label = 'ground truth')
plt.fill_between(x_pred, post_mean + 3.0 *np.sqrt(post_var),post_mean - 3.0 * np.sqrt(post_var), color = 'grey', alpha = 0.5)
plt.legend()
[Figure: posterior mean ('gp interpolation'), data, and ground truth, with a ±3 standard deviation band]
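The plotted quantities follow the standard zero-mean GP posterior formulas: the mean is k_*^T K^{-1} y and the variance is k(x_*, x_*) - k_*^T K^{-1} k_*. The NumPy sketch below computes both via a Cholesky factorization. The squared-exponential kernel, the hyperparameter values [1.0, 2.0], and the 1e-6 noise term are illustrative assumptions, not the values fvGP trained above.

```python
import numpy as np

def function(x):
    return np.sin(1.1 * x) + np.cos(0.5 * x)

def sq_exp_kernel(x1, x2, hp):
    # Illustrative kernel: hp[0] = signal variance, hp[1] = length scale
    d = np.subtract.outer(x1, x2)
    return hp[0] * np.exp(-0.5 * (d / hp[1]) ** 2)

x_data = np.linspace(-2 * np.pi, 10 * np.pi, 50)
y_data = function(x_data)
x_pred = np.linspace(-2 * np.pi, 10 * np.pi, 1000)

hp = np.array([1.0, 2.0])   # illustrative values, not the trained ones
noise = 1e-6                # small diagonal term for numerical stability

K = sq_exp_kernel(x_data, x_data, hp) + noise * np.eye(len(x_data))
k_star = sq_exp_kernel(x_data, x_pred, hp)                # (50, 1000) cross-covariance
L = np.linalg.cholesky(K)
alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_data))  # K^{-1} y
v = np.linalg.solve(L, k_star)                            # L^{-1} k_*

post_mean = k_star.T @ alpha                              # k_*^T K^{-1} y
post_var = hp[0] - np.sum(v ** 2, axis=0)                 # k(x_*, x_*) - k_*^T K^{-1} k_*
post_var = np.maximum(post_var, 0.0)                      # clip tiny negative round-off

# The same 3-sigma band used in fill_between above
upper = post_mean + 3.0 * np.sqrt(post_var)
lower = post_mean - 3.0 * np.sqrt(post_var)
```

The variance shrinks toward the noise level at the data points and never exceeds the prior variance hp[0], which is why the grey band pinches at each observation.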

Training Asynchronously

obj = gp.GP(1, x_data,y_data, init_hyperparameters = np.array([10,10]),
                            variances = np.zeros(y_data.reshape(-1,1).shape))
hyper_param_bounds = np.array([[0.0001, 100], [ 0.0001, 100]])
async_obj = obj.train_async(hyper_param_bounds)

Updating asynchronously

This updates the hyperparameters to the current values of the running optimization.

obj.update_hyperparameters(async_obj)
array([10, 10])

Killing training

obj.kill_training(async_obj)
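The asynchronous workflow (start training, read the current best values without waiting, then stop the optimizer) can be sketched with a background thread. This is a hypothetical stand-in, not fvGP's implementation: the class AsyncSearch, its methods current_best and kill, the random-search optimizer, and the toy objective are all assumptions made for illustration. If the background search has not found anything better yet, reading it returns the starting point, which may be why update_hyperparameters above still returns array([10, 10]).

```python
import threading
import time
import numpy as np

class AsyncSearch:
    """Hypothetical asynchronous optimizer (illustration only, not fvGP's class)."""

    def __init__(self, objective, bounds, x0, seed=0):
        self.objective = objective
        self.bounds = np.asarray(bounds, dtype=float)
        self.best_x = np.asarray(x0, dtype=float)
        self.best_f = objective(self.best_x)
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._rng = np.random.default_rng(seed)
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()  # training runs in the background from here on

    def _run(self):
        # Random search inside the bounds until asked to stop
        while not self._stop.is_set():
            cand = self._rng.uniform(self.bounds[:, 0], self.bounds[:, 1])
            f = self.objective(cand)
            with self._lock:
                if f < self.best_f:
                    self.best_x, self.best_f = cand, f

    def current_best(self):
        # Similar in spirit to update_hyperparameters: read without waiting
        with self._lock:
            return self.best_x.copy()

    def kill(self):
        # Similar in spirit to kill_training: signal the worker and wait for it
        self._stop.set()
        self._thread.join()

def objective(hp):
    # Toy objective with its minimum at (2.0, 3.5)
    return (hp[0] - 2.0) ** 2 + (hp[1] - 3.5) ** 2

search = AsyncSearch(objective, bounds=[[0.0001, 100], [0.0001, 100]], x0=[10, 10])
early = search.current_best()   # may still equal the starting point
time.sleep(0.2)                 # the main thread is free to do other work here
later = search.current_best()
search.kill()
print(early, later)
```

The lock keeps reads consistent while the worker writes, and the Event lets the main thread stop the loop cleanly instead of abandoning the thread.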

Looking at the posterior mean at the test points

post_mean= obj.posterior_mean(x_pred.reshape(-1,1))['f(x)']
# recompute the variance for this model; post_var from the earlier model would be stale
post_var= obj.posterior_covariance(x_pred.reshape(-1,1))["v(x)"]
plt.plot(x_pred, post_mean, label='interpolation')
plt.scatter(x_data, y_data, label='data')
plt.plot(x_pred, function(x_pred), label='ground truth')
plt.fill_between(x_pred, post_mean + 3.0 *np.sqrt(post_var),post_mean - 3.0 * np.sqrt(post_var), color = 'grey', alpha = 0.5)
plt.legend()
[Figure: posterior mean ('interpolation') after asynchronous training, data, and ground truth, with a ±3 standard deviation band]

Custom Kernels

def kernel_l1(x1,x2, hp, obj):
    ################################################################
    ###exponential kernel on the l1 distance (first input dim)#####
    ###hp[0]: signal variance, hp[1]: length scale##################
    ################################################################
    d1 = np.abs(np.subtract.outer(x1[:,0],x2[:,0]))
    return hp[0] * np.exp(-d1/hp[1])
obj = gp.GP(1, x_data,y_data,
                init_hyperparameters = np.array([10,10]),
                variances = np.zeros(y_data.shape),
                gp_kernel_function = kernel_l1)
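Since kernel_l1 only looks at the first input dimension, here is a sketch of how it might be extended to several dimensions with one length scale per dimension, which makes it genuinely anisotropic. The function kernel_l1_aniso and its hyperparameter layout (hp[0] signal variance, hp[1:] length scales) are hypothetical. Because the result is a product of one-dimensional exponential kernels, it stays positive semi-definite, which the eigenvalue check below probes numerically.

```python
import numpy as np

def kernel_l1(x1, x2, hp, obj):
    # 1D kernel from the notebook, for comparison
    d1 = np.abs(np.subtract.outer(x1[:, 0], x2[:, 0]))
    return hp[0] * np.exp(-d1 / hp[1])

def kernel_l1_aniso(x1, x2, hp, obj=None):
    """Hypothetical multi-dimensional extension of kernel_l1.

    x1: (N1, dim), x2: (N2, dim); hp[0] = signal variance,
    hp[1:] = one length scale per input dimension.
    """
    d = np.zeros((x1.shape[0], x2.shape[0]))
    for i in range(x1.shape[1]):
        # weighted l1 distance: sum over dims of |x1_i - x2_i| / l_i
        d += np.abs(np.subtract.outer(x1[:, i], x2[:, i])) / hp[1 + i]
    return hp[0] * np.exp(-d)

rng = np.random.default_rng(0)
x2d = rng.uniform(0.0, 10.0, size=(30, 2))
K = kernel_l1_aniso(x2d, x2d, np.array([2.0, 1.0, 5.0]))
print(K.shape, np.linalg.eigvalsh(K).min())

# In one dimension the extension reduces to kernel_l1
x1d = np.linspace(0.0, 5.0, 10).reshape(-1, 1)
same = np.allclose(kernel_l1_aniso(x1d, x1d, [1.5, 2.0]),
                   kernel_l1(x1d, x1d, [1.5, 2.0], None))
print(same)
```

Such a kernel would need len(hp) == 1 + input_space_dim, so the init_hyperparameters and hyper_param_bounds arrays would grow accordingly.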

Training our Gaussian process regression on the given data

hyper_param_bounds = np.array([[0.0001, 1000],[ 0.0001, 1000]])
obj.train(hyper_param_bounds)

Looking at the posterior mean at the test points

post_mean= obj.posterior_mean(x_pred.reshape(-1,1))["f(x)"]
# recompute the variance for this model; post_var from the earlier model would be stale
post_var= obj.posterior_covariance(x_pred.reshape(-1,1))["v(x)"]
plt.plot(x_pred, post_mean, label='interpolation')
plt.scatter(x_data, y_data, label='data')
plt.plot(x_pred, function(x_pred), label='ground truth')
plt.fill_between(x_pred, post_mean + 3.0 *np.sqrt(post_var),post_mean - 3.0 * np.sqrt(post_var), color = 'grey', alpha = 0.5)
plt.legend()
[Figure: posterior mean ('interpolation') with the custom l1 kernel, data, and ground truth, with a ±3 standard deviation band]

Prior Mean Functions

NOTE: The prior mean function must return a 1D array of shape (N,), e.g., (100,) for 100 input points.

def example_mean(x,hyperparameters,gp_obj):
    return np.ones(len(x))
obj = gp.GP(1, x_data,y_data,init_hyperparameters = np.array([10,10]),
                            variances = np.zeros(y_data.shape),
                            gp_mean_function = example_mean)
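One plausible reason for the 1D requirement: a GP posterior mean has the form m(x_*) + k_*^T K^{-1} (y - m(X)), so the prior mean is subtracted from the data, and NumPy broadcasting silently turns an (N,) vector minus an (N x 1) column into an (N x N) matrix. The sketch below shows the effect; bad_mean is a hypothetical counterexample, and it assumes the outputs are held as a flat (N,) array, without claiming how fvGP stores them.

```python
import numpy as np

def example_mean(x, hyperparameters, gp_obj):
    # Returns shape (N,), as the note requires
    return np.ones(len(x))

def bad_mean(x, hyperparameters, gp_obj):
    # Hypothetical mistake: returns an (N, 1) column instead of a 1D vector
    return np.ones((len(x), 1))

x = np.linspace(-2 * np.pi, 10 * np.pi, 50).reshape(-1, 1)
y = (np.sin(1.1 * x) + np.cos(0.5 * x)).ravel()   # shape (50,)

good_residual = y - example_mean(x, None, None)   # stays (50,)
bad_residual = y - bad_mean(x, None, None)        # broadcasts to (50, 50)
print(good_residual.shape, bad_residual.shape)
```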

Training our Gaussian process regression on the given data

hyper_param_bounds = np.array([[0.0001, 1000],[ 0.0001, 1000]])
obj.train(hyper_param_bounds)

Looking at the posterior mean at the test points

post_mean= obj.posterior_mean(x_pred.reshape(-1,1))["f(x)"]
# recompute the variance for this model; post_var from the earlier model would be stale
post_var= obj.posterior_covariance(x_pred.reshape(-1,1))["v(x)"]
plt.plot(x_pred, post_mean, label='interpolation')
plt.scatter(x_data, y_data, label='data')
plt.plot(x_pred, function(x_pred), label='ground truth')
plt.fill_between(x_pred, post_mean + 3.0 *np.sqrt(post_var),post_mean - 3.0 * np.sqrt(post_var), color = 'grey', alpha = 0.5)
plt.legend()
[Figure: posterior mean ('interpolation') with the constant prior mean, data, and ground truth, with a ±3 standard deviation band]