GridSearch optimization
The SciKit-GStat package can be connected to scikit-learn, e.g. to use its model optimization sub-package. In this example, different options are compared and cross-validated. The interface offers two ways to evaluate a variogram model fit:
goodness of fit measures of the spatial model itself
cross-validation of the variogram by interpolating the observation points
Both options can use the RMSE, MSE and MAE as a metric.
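The second option can be illustrated with a minimal leave-one-out sketch: each observation is predicted from all the others, and the residuals are summarized as RMSE, MSE and MAE. This is not the hydrobox or SciKit-GStat implementation. A simple inverse-distance interpolator stands in for kriging, the helper names are hypothetical, and the observations are made up.

```python
import math

def idw_predict(points, values, target, power=2.0):
    """Inverse-distance-weighted estimate at `target` (a stand-in for kriging)."""
    num = 0.0
    den = 0.0
    for (x, y), v in zip(points, values):
        d = math.hypot(x - target[0], y - target[1])
        if d == 0.0:
            return v  # target coincides with an observation
        w = 1.0 / d ** power
        num += w * v
        den += w
    return num / den

def leave_one_out_errors(points, values):
    """Predict every observation from all the others and return the residuals."""
    residuals = []
    for i in range(len(points)):
        rest_pts = points[:i] + points[i + 1:]
        rest_vals = values[:i] + values[i + 1:]
        residuals.append(idw_predict(rest_pts, rest_vals, points[i]) - values[i])
    return residuals

def metrics(residuals):
    """Summarize residuals with the three metrics named in the text."""
    n = len(residuals)
    mse = sum(r * r for r in residuals) / n
    mae = sum(abs(r) for r in residuals) / n
    return {'rmse': math.sqrt(mse), 'mse': mse, 'mae': mae}

# made-up toy observations, for illustration only
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5)]
values = [1.0, 2.0, 2.0, 3.0, 2.0]

cv_scores = metrics(leave_one_out_errors(points, values))
print(cv_scores)
```

A lower value means a better prediction of the held-out points for all three metrics.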
import plotly
import plotly.graph_objects as go
import hydrobox
from hydrobox.data import pancake
from hydrobox.plotting import plotting_backend
plotting_backend('plotly')
Load sample data from the data sub-module
df = pancake()
First, a variogram is estimated. This fixes all arguments that the grid search should not vary.
vario = hydrobox.geostat.variogram(
    coordinates=df[['x', 'y']].values,
    values=df.z.values,
    maxlag=500,
    bin_func='kmeans',
    return_type='object'
)
The next step is to create a parameter grid, which specifies the value space for each parameter that should be checked. Here, we will try all combinations of different models and lag classes.
param_grid = {
    'model': ('spherical', 'exponential', 'matern'),
    'n_lags': (15, 20, 25, 30, 35)
}
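To see how large the search space is, the grid can be expanded by hand. The sketch below uses only the standard library and assumes the search enumerates the grid as a full Cartesian product with keys in sorted order, which is how scikit-learn's ParameterGrid behaves. It is an illustration, not the code hydrobox runs.

```python
from itertools import product

param_grid = {
    'model': ('spherical', 'exponential', 'matern'),
    'n_lags': (15, 20, 25, 30, 35),
}

# expand the grid into one dict per combination (last key varies fastest)
keys = sorted(param_grid)
combinations = [
    dict(zip(keys, combo))
    for combo in product(*(param_grid[k] for k in keys))
]

print(len(combinations))   # 3 models x 5 lag settings = 15
print(combinations[0])     # {'model': 'spherical', 'n_lags': 15}
```

Every additional parameter multiplies the number of variogram fits, so the grid should be kept as small as the question allows.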
First, the model fit itself is evaluated, and only the best parameter set is returned.
best_param = hydrobox.geostat.gridsearch(
    param_grid=param_grid,
    variogram=vario,
    coordinates=None,       # must be set if variogram is None
    values=None,            # must be set if variogram is None
    score='rmse',           # default
    cross_validate=False,   # evaluate the model fit
    n_jobs=-1,              # use parallel mode
    return_type='best_param'
)
print(best_param)
Out:
{'model': 'spherical', 'n_lags': 35}
It is also possible to return the underlying GridSearchCV instance. This class holds much more information than just the best parameters.
# rerun the same grid search, but return the object
clf = hydrobox.geostat.gridsearch(
    param_grid=param_grid,
    variogram=vario,
    coordinates=None,       # must be set if variogram is None
    values=None,            # must be set if variogram is None
    score='rmse',           # default
    cross_validate=False,   # evaluate the model fit
    n_jobs=-1,              # use parallel mode
    return_type='object'
)
# get the scores and their std
scores = clf.cv_results_['mean_test_score']
scores_std = clf.cv_results_['std_test_score']
x = list(range(len(scores)))
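Each position in x corresponds to one parameter combination. In scikit-learn, cv_results_ also carries a 'params' list in the same order as the score arrays, so positions can be labelled. The sketch below works on a hypothetical stand-in dictionary with made-up numbers rather than on the real clf, so it runs without fitting anything.

```python
# Hypothetical stand-in for clf.cv_results_; all numbers are made up.
cv_results = {
    'params': [
        {'model': 'spherical', 'n_lags': 15},
        {'model': 'spherical', 'n_lags': 20},
        {'model': 'exponential', 'n_lags': 15},
    ],
    'mean_test_score': [0.9, 0.7, 0.8],
    'std_test_score': [0.10, 0.05, 0.20],
}

# one readable label per position on the x axis
labels = [f"{p['model']} / {p['n_lags']}" for p in cv_results['params']]

# (label, lower band, mean, upper band) for every parameter combination
rows = [
    (label, mean - std, mean, mean + std)
    for label, mean, std in zip(
        labels, cv_results['mean_test_score'], cv_results['std_test_score']
    )
]

for label, lower, mean, upper in rows:
    print(f"{label:<20} {lower:.2f} {mean:.2f} {upper:.2f}")
```

Labels built this way could be passed to the plot as tick text in place of the bare integer positions.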
Plot the result
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=x, y=scores, mode='lines', line_color='#A3ACF7', name='RMSE score')
)
fig.add_trace(
    go.Scatter(x=x, y=scores + scores_std, mode='lines', line_color='#BAC1F2', fill='tonexty', name='RMSE + std')
)
fig.add_trace(
    go.Scatter(x=x, y=scores - scores_std, mode='lines', line_color='#BAC1F2', fill='tonexty', name='RMSE - std')
)
fig.update_layout(
    template='plotly_white'
)
# show the plot
plotly.io.show(fig)