
Training SDMs with SDeMo

The purpose of this vignette is to take a whistle-stop tour of what the package has to offer. We will see how to train a model, cross-validate it, select variables, tune hyper-parameters, use bootstrapping to measure uncertainty, show partial response curves, calculate variable importance through bootstrapping and Shapley values, and generate counterfactual inputs.

Mapping the outputs

The functions to interact with raster data, and present the outputs of models as maps, are implemented as part of SpeciesDistributionToolkit (SDT). There is a follow-up tutorial that focuses on visualisation of SDM outputs.

This vignette is very terse and is not an introduction to using or interpreting these models!

Setting up the environment

julia
using SpeciesDistributionToolkit
using Statistics
using CairoMakie
using PrettyTables

The package comes with demonstration data representing the presences and absences of Sitta whiteheadi at about 1500 locations in Corsica, described by 19 bioclim environmental variables:

julia
X, y = SDeMo.__demodata()
size(X)
(19, 1484)
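
As a quick sanity check, the prevalence of presences can be computed directly from the labels (this assumes y is a vector of Booleans, which the presence/absence framing suggests; the value is not shown here):

julia
mean(y)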

Setting up the model

We will start with an initial model that uses a PCA to transform the data, and then a Naive Bayes classifier for the classification. Note that this is the partial syntax, which uses the default threshold and all of the variables:

julia
sdm = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
PCATransform → NaiveBayes → P(x) ≥ 0.5

Transformations

The transformations will by default only account for the presences. This is motivated by the fact that absences will often be pseudo-absences, and therefore should not be used to transform the data. There are ways to use absences to transform the data as well, which are presented in the SDeMo manual.

Initial cross-validation

Throughout this demonstration, we will use the same splits between validation and training data, with 10 folds:

julia
folds = kfold(sdm; k = 10);

Using the same splits is important because it will provide a fair basis of comparison for the different models.

Before doing any manipulation, we can cross-validate the model: this will return a named tuple with the confusion matrices for the validation and training data:

julia
cv = crossvalidate(sdm, folds);

The reason we return confusion matrices, as opposed to directly measuring statistics, is that it gives us more freedom to manipulate the results without needing to re-run the analysis. For example, we can get the average of a handful of measures on this first round of cross-validation:

julia
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [measure(set) for measure in measures, set in cv]
pretty_table(
    hcat(string.(measures), cvresult);
    alignment = [:l, :c, :c],
    backend = Val(:markdown),
    header = ["Measure", "Validation", "Training"],
    formatters = ft_printf("%5.3f", [2, 3]),
)
| Measure          | Validation | Training |
|:-----------------|:----------:|:--------:|
| mcc              |   0.501    |  0.502   |
| balancedaccuracy |   0.749    |  0.748   |
| ppv              |   0.730    |  0.730   |
| npv              |   0.776    |  0.776   |
| trueskill        |   0.498    |  0.497   |
| markedness       |   0.506    |  0.506   |

By default, calling a function to measure model performance (the full list is in the manual) on a series of confusion matrices will return the average. Passing true as an additional argument returns a tuple with the average and the 95% CI:

julia
mcc(cv.validation, true)
(0.5014894802983206, 0.04048032968194452)

If we want to get a simple idea of what the MCC is for the validation data for all folds, we can use:

julia
mcc.(cv.validation)
10-element Vector{Float64}:
 0.411601475294833
 0.5674044484888594
 0.4655712348560116
 0.47520253719491684
 0.4012480832119629
 0.5132755881718536
 0.5282350793547591
 0.4949922840967364
 0.6023685038544081
 0.5549955684588657

The associated confidence interval is

julia
ci(cv.validation, mcc)
0.04048032968194452

Note that cross-validation does not train the model. This is meant to convey the idea that cross-validation lets us know whether the model is good enough to be trained. If we want to forgo variable selection, we can train the model with:

julia
train!(sdm)
PCATransform → NaiveBayes → P(x) ≥ 0.39
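
Once trained, the model can produce thresholded predictions for every observation; for instance, counting the number of predicted presences is a one-liner (output omitted here):

julia
sum(predict(sdm))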

Variable selection

We will now select variables using forward selection, but with the added constraint that the first variable (annual mean temperature) must be included:

julia
forwardselection!(sdm, folds, [1])
PCATransform → NaiveBayes → P(x) ≥ 0.322

This operation will retrain the model. We can now look at the list of selected variables:

julia
variables(sdm)
6-element Vector{Int64}:
  1
  3
  6
  4
 15
  2

Is this making a difference in terms of cross-validation?

julia
cv2 = crossvalidate(sdm, folds)
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [measure(set) for measure in measures, set in cv2]
pretty_table(
    hcat(string.(measures), cvresult);
    alignment = [:l, :c, :c],
    backend = Val(:markdown),
    header = ["Measure", "Validation", "Training"],
    formatters = ft_printf("%5.3f", [2, 3]),
)
| Measure          | Validation | Training |
|:-----------------|:----------:|:--------:|
| mcc              |   0.788    |  0.785   |
| balancedaccuracy |   0.898    |  0.896   |
| ppv              |   0.842    |  0.842   |
| npv              |   0.938    |  0.937   |
| trueskill        |   0.796    |  0.792   |
| markedness       |   0.780    |  0.779   |
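
To put a number on the difference, we can compare the average validation MCC before and after variable selection, reusing the cv and cv2 objects from above:

julia
mean(mcc.(cv2.validation)) - mean(mcc.(cv.validation))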

Quite clearly! Before thinking about the relative importance of variables, we will take a look at the threshold.

Moving threshold classification

The crossvalidate function comes with an optional argument to specify the threshold between a positive and negative outcome. We can use this to figure out the best possible threshold:

julia
thresholds = LinRange(0.01, 0.99, 200)
cv3 = [crossvalidate(sdm, folds; thr = t) for t in thresholds];

We can plot the simplified version of this analysis (averaged across all folds for each value of the threshold):

Code for the figure
julia
f = Figure()
ax = Axis(f[1, 1]; aspect = 1, xlabel = "Threshold", ylabel = "MCC")
scatter!(ax, thresholds, [mean(mcc.(s.validation)) for s in cv3])
ax2 = Axis(f[1, 2]; aspect = 1, xlabel = "Precision", ylabel = "Recall")
scatter!(
    ax2,
    [mean(SDeMo.precision.(s.validation)) for s in cv3],
    [mean(SDeMo.recall.(s.validation)) for s in cv3],
)
xlims!(ax, 0.0, 1.0)
ylims!(ax, 0.0, 1.0)
xlims!(ax2, 0.0, 1.0)
ylims!(ax2, 0.0, 1.0)
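
Out of curiosity, we could also locate the candidate threshold that maximizes the average validation MCC ourselves, reusing the cv3 object; this is purely illustrative, since the trained model already stores a threshold (shown just below):

julia
thresholds[argmax([mean(mcc.(s.validation)) for s in cv3])]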

For reference, the threshold at the end of this process is:

julia
threshold(sdm)
0.3215740401355837

Variable importance

We can evaluate the importance of each variable by bootstrap: the values of a variable are randomized, and the absolute value of the difference in model performance is returned (averaged across all folds):

julia
varimp = variableimportance(sdm, folds)
6-element Vector{Float64}:
 0.27323841649892033
 0.07389908443010518
 0.28296741279358817
 0.17955314447782553
 0.03419082453433911
 0.002253220156847351

In relative terms, this is:

julia
pretty_table(
    hcat(variables(sdm), varimp ./ sum(varimp));
    alignment = [:l, :c],
    backend = Val(:markdown),
    header = ["Variable", "Importance"],
    formatters = (ft_printf("%5.3f", 2), ft_printf("%d", 1)),
)
| Variable | Importance |
|:---------|:----------:|
| 1        |   0.323    |
| 3        |   0.087    |
| 6        |   0.334    |
| 4        |   0.212    |
| 15       |   0.040    |
| 2        |   0.003    |
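
If we want the selected variables ranked from most to least important, we can sort them by their bootstrap score (plain Julia, reusing the varimp vector computed above):

julia
variables(sdm)[sortperm(varimp; rev = true)]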

Partial response curve

Another way to look at the effect of variables is to use the partial response curves. For example, we can look at the predictions of the model between 5 and 15 degrees:

julia
prx, pry = partialresponse(sdm, 1, LinRange(5.0, 15.0, 100); threshold = false);

Note that we use threshold=false to make sure that we look at the score that is returned by the classifier, and not the thresholded version.

Code for the figure
julia
f = Figure()
ax = Axis(f[1, 1]; xlabel = "BIO1", ylabel = "Partial response")
lines!(ax, prx, pry; color = :black)
hlines!(ax, [threshold(sdm)]; color = :red, linestyle = :dash)

We can also show the response surface using two variables:

julia
prx, pry, prz = partialresponse(sdm, variables(sdm)[1:2]..., (100, 100); threshold = false);

Note that the last element returned in this case is a two-dimensional array, as it makes sense to visualize the result as a heatmap. Although the idea of the partial response curve generalizes to more than two dimensions, this is not supported by the package.

Code for the figure
julia
f = Figure()
ax = Axis(f[1, 1]; xlabel = "BIO$(variables(sdm)[1])", ylabel = "BIO$(variables(sdm)[2])")
cm = heatmap!(prx, pry, prz; colormap = :Oranges)
Colorbar(f[1, 2], cm)

Inflated partial responses

Inflated partial responses replace the average value by other values drawn from different quantiles of the variables:

Code for the figure
julia
f = Figure()
ax = Axis(f[1, 1])
prx, pry = partialresponse(sdm, 1; inflated = false, threshold = false)
for i in 1:200
    ix, iy = partialresponse(sdm, 1; inflated = true, threshold = false)
    lines!(ax, ix, iy; color = (:grey, 0.5))
end
lines!(ax, prx, pry; color = :black, linewidth = 4)

Measuring uncertainty with bagging

We can wrap our model into a homogeneous ensemble:

julia
ensemble = Bagging(sdm, 20)
train!(ensemble)
{PCATransform → NaiveBayes → P(x) ≥ 0.322} × 20

Ensemble models can be used in the same way as regular models, but take an extra keyword argument, consensus, to indicate how the results from the individual models should be reconciled. By default, this is the median, but for uncertainty, iqr is a great choice. Note that we use threshold=false because we want to look at the distribution of the score, not the Boolean output:

julia
uncert = predict(ensemble; consensus = iqr, threshold = false);

Code for the figure
julia
hist(uncert; color = :grey, axis = (; xlabel = "Uncertainty (IQR)"))
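
The point predictions and their uncertainty can also be looked at side by side: the default median consensus gives the ensemble score, while the IQR computed above quantifies the disagreement between the bagged models (a sketch using calls already shown in this vignette; output omitted):

julia
score = predict(ensemble; consensus = median, threshold = false)
hcat(score, uncert)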

Heterogeneous ensembles

We can set up a heterogeneous ensemble model by passing several SDMs to Ensemble:

julia
m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
m2 = SDM(RawData, BIOCLIM, X, y)
m3 = SDM(MultivariateTransform{PCA}, BIOCLIM, X, y)
variables!(m2, [1, 12])
hm = Ensemble(m1, m2, m3)
An ensemble model with:
	 PCATransform → NaiveBayes → P(x) ≥ 0.5
	 RawData → BIOCLIM → P(x) ≥ 0.01
	 PCATransform → BIOCLIM → P(x) ≥ 0.01

We can train this model in the same way:

julia
train!(hm)
An ensemble model with:
	 PCATransform → NaiveBayes → P(x) ≥ 0.39
	 RawData → BIOCLIM → P(x) ≥ 0.11
	 PCATransform → BIOCLIM → P(x) ≥ 0.102

And get predictions:

julia
predict(hm; consensus = median, threshold = false)[1:10]
10-element Vector{Float64}:
 0.5181674565560821
 0.5023696682464455
 0.39178515007898895
 0.6350710900473934
 0.20853080568720372
 0.639106860018156
 0.05055292259083721
 0.2559241706161137
 0.18957345971563977
 0.593996840442338

Note that the ensemble models (Bagging and Ensemble) are supported by the explain, partialresponse, and counterfactual functions.
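
For example, the two-argument partial response call used earlier should also work on the heterogeneous ensemble; the sketch below assumes the call signature carries over unchanged from the single-model case (the exact keyword arguments, such as consensus, may differ, so check the SDeMo manual):

julia
prx_hm, pry_hm = partialresponse(hm, 1, LinRange(5.0, 15.0, 100); threshold = false);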

Explaining predictions

We can measure (the MCMC version of) Shapley values using the explain method:

julia
[explain(sdm, v; observation = 3) for v in variables(sdm)]
6-element Vector{Float64}:
 -0.19
  0.04
 -0.18
 -0.005
  0.02
 -0.005

These values are returned as the effect of each variable's value on the average prediction for this observation.
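
For instance, to find which of the selected variables pushes this particular prediction the furthest away from the average, we can look for the largest absolute effect (reusing the call shown above):

julia
shap = [explain(sdm, v; observation = 3) for v in variables(sdm)]
variables(sdm)[argmax(abs.(shap))]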

We can also produce a figure that looks like the partial response curve, by showing the effect of a variable on each training instance:

Code for the figure
julia
f = Figure()
ax = Axis(f[1, 1]; xlabel = "BIO1", ylabel = "Effect on the average prediction")
scatter!(ax, features(sdm, 1), explain(sdm, 1; threshold = true); color = :purple)

Counterfactuals

In the final example, we will focus on generating a counterfactual input, i.e. a set of hypothetical inputs that leads the model to predict the other outcome. Internally, candidate points are generated using the Nelder-Mead algorithm, which works well enough but is not compatible with categorical data.

We will pick one prediction to flip:

julia
inst = 4
4

And look at its outcome:

julia
outcome = predict(sdm)[inst]
true

Our target is expressed in terms of the score we want the counterfactual to reach (and not in terms of true/false; this is very important):

julia
target = outcome ? 0.9threshold(sdm) : 1.1threshold(sdm)
0.28941663612202534

The actual counterfactual is generated as:

julia
cf = counterfactual(
    sdm,
    instance(sdm, inst; strict = false),
    target,
    200.0;
    threshold = false,
)
19-element Vector{Float64}:
  13.228503368671777
   3.812552349998173
  20.8990908875656
 505.0859650219775
  16.7
  -0.31810197778322247
  18.3
   5.6000000000000005
  13.200000000000001
  14.299999999999997
   0.30000000000000004
 924.0
 135.0
  14.0
  47.974493367547375
 382.0
  61.0
  72.0
 283.0

The last value (set to 200.0 here) is the learning rate, which usually needs to be tuned. The counterfactual input for the observation we are interested in is:

julia
pretty_table(
    hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
    alignment = [:l, :c, :c],
    backend = Val(:markdown),
    header = ["Variable", "Obs.", "Counterf."],
    formatters = (ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1)),
)
| Variable | Obs.  | Counterf. |
|:---------|:-----:|:---------:|
| 1        |  6.6  |   13.2    |
| 3        | 21.0  |   20.9    |
| 6        | -1.6  |   -0.3    |
| 4        | 505.1 |   505.1   |
| 15       | 48.0  |   48.0    |
| 2        |  3.8  |    3.8    |

We can check the prediction that would be made on the counterfactual:

julia
predict(sdm, cf)
false
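
Since the learning rate usually needs tuning, one possible (purely illustrative) strategy is to scan a few candidate rates and keep the first counterfactual whose thresholded prediction actually flips; the rates below are arbitrary examples, not recommendations:

julia
for rate in (50.0, 100.0, 200.0, 400.0)
    candidate = counterfactual(
        sdm,
        instance(sdm, inst; strict = false),
        target,
        rate;
        threshold = false,
    )
    # Stop as soon as the thresholded prediction differs from the original outcome
    if predict(sdm, candidate) != outcome
        @info "Prediction flipped" rate
        break
    end
end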

Conclusion

This vignette offered a very quick overview of what the SDeMo package does. The integration with the rest of the SpeciesDistributionToolkit ecosystem is detailed in the main documentation.