Counterfactuals
The purpose of this vignette is to show how to generate counterfactual explanations from SDeMo models.
using SpeciesDistributionToolkit
using CairoMakie
using PrettyTables
We will work on the demo data:
X, y = SDeMo.__demodata()
sdm = SDM(RawData, NaiveBayes, X, y)
variables!(sdm, [1, 12])  # the model only uses variables 1 and 12
train!(sdm)
RawData → NaiveBayes → P(x) ≥ 0.477
We will focus on generating counterfactual inputs, i.e. hypothetical inputs that lead the model to predict the opposite outcome. Internally, candidate points are generated using the Nelder-Mead algorithm, which works well enough but, because it moves points through a continuous space, is not compatible with categorical data.
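Nelder-Mead is a derivative-free method: it keeps a simplex of n + 1 candidate points and repeatedly reflects, expands, contracts or shrinks it, using only function values. The sketch below is a generic textbook version in plain Julia; the function name nelder_mead and its keyword arguments are hypothetical, and this is not the code used inside SDeMo. Since every step averages and shifts points in a continuous space, it also shows why categorical variables cannot be handled this way.

```julia
# Generic Nelder-Mead simplex search (textbook variant, plain Julia).
# Hypothetical illustration only; not the optimiser used inside SDeMo.
function nelder_mead(f, x0::Vector{Float64}; step = 1.0, iters = 1000, tol = 1e-12)
    n = length(x0)
    # Initial simplex: x0 plus one vertex shifted along each axis
    simplex = [copy(x0) for _ in 1:(n + 1)]
    for i in 1:n
        simplex[i + 1][i] += step
    end
    fvals = [f(v) for v in simplex]
    for _ in 1:iters
        # Sort vertices from best (lowest f) to worst
        order = sortperm(fvals)
        simplex, fvals = simplex[order], fvals[order]
        fvals[end] - fvals[1] < tol && break
        # Centroid of every vertex except the worst one
        centroid = sum(simplex[1:n]) / n
        worst = simplex[end]
        # Reflect the worst vertex through the centroid
        xr = centroid + (centroid - worst)
        fr = f(xr)
        if fr < fvals[1]
            # Reflection is the new best point: try expanding further
            xe = centroid + 2 * (centroid - worst)
            fe = f(xe)
            simplex[end], fvals[end] = fe < fr ? (xe, fe) : (xr, fr)
        elseif fr < fvals[n]
            # Reflection beats the second-worst vertex: accept it
            simplex[end], fvals[end] = xr, fr
        else
            xc = if fr < fvals[end]
                centroid + 0.5 * (xr - centroid)     # outside contraction
            else
                centroid + 0.5 * (worst - centroid)  # inside contraction
            end
            fc = f(xc)
            if fc < min(fr, fvals[end])
                simplex[end], fvals[end] = xc, fc
            else
                # Shrink every vertex toward the best one
                for i in 2:(n + 1)
                    simplex[i] = simplex[1] + 0.5 * (simplex[i] - simplex[1])
                    fvals[i] = f(simplex[i])
                end
            end
        end
    end
    best = argmin(fvals)
    return simplex[best], fvals[best]
end

# Minimise a simple quadratic whose minimum is at (1, 2)
xbest, fbest = nelder_mead(x -> (x[1] - 1)^2 + (x[2] - 2)^2, [0.0, 0.0])
```

On this quadratic, xbest should end up very close to [1.0, 2.0].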
We will pick one prediction to flip:
inst = 6
6
And look at its outcome:
outcome = predict(sdm)[inst]
true
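Our target is expressed as the score we want the counterfactual to reach, and not as a true/false outcome; this distinction is very important. Because the current prediction is true, the target is set 10% below the threshold (it would be set 10% above it otherwise):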
target = outcome ? 0.9threshold(sdm) : 1.1threshold(sdm)
0.42889409290513314
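The ternary expression above places the target on the opposite side of the decision threshold, with a 10% margin, presumably so that the counterfactual does not land exactly on the boundary. A minimal sketch of that rule in isolation, using hypothetical helper names and a made-up threshold of 0.5 rather than the value learned by the model:

```julia
# Hypothetical helper mirroring the rule above: aim 10% below the threshold
# when the current prediction is true, 10% above it otherwise.
counterfactual_target(outcome::Bool, τ::Real) = outcome ? 0.9τ : 1.1τ

# A score is classified as true when it reaches the threshold, P(x) ≥ τ
classify(score::Real, τ::Real) = score >= τ

τ = 0.5  # hypothetical threshold, not the one learned by the model above
for o in (true, false)
    t = counterfactual_target(o, τ)
    # The target score always falls on the other side of the threshold
    println(o, " → target ", t, " → classified as ", classify(t, τ))
end
```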
We generate five counterfactuals for the same observation. The full instance is passed (strict = false), but only the variables used by the model (1 and 12) are modified:
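# Generate five counterfactuals for the same observation
cf = [
    counterfactual(
        sdm,
        instance(sdm, inst; strict = false),  # all variables of the observation
        target,                               # score the counterfactual should reach
        200.0;                                # learning rate (see below)
        threshold = false,                    # the target is a score, not true/false
    ) for _ in 1:5
]
# Bind the five counterfactuals as the columns of a matrix
cf = hcat(cf...)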
19×5 Matrix{Float64}:
11.5982 13.0826 10.9653 13.2026 12.7248
3.4 3.4 3.4 3.4 3.4
18.8 18.8 18.8 18.8 18.8
509.6 509.6 509.6 509.6 509.6
19.7 19.7 19.7 19.7 19.7
1.7 1.7 1.7 1.7 1.7
18.0 18.0 18.0 18.0 18.0
8.7 8.7 8.7 8.7 8.7
16.4 16.4 16.4 16.4 16.4
17.5 17.5 17.5 17.5 17.5
3.4 3.4 3.4 3.4 3.4
768.098 942.479 725.092 948.048 907.891
133.0 133.0 133.0 133.0 133.0
15.0 15.0 15.0 15.0 15.0
44.0 44.0 44.0 44.0 44.0
381.0 381.0 381.0 381.0 381.0
69.0 69.0 69.0 69.0 69.0
85.0 85.0 85.0 85.0 85.0
281.0 281.0 281.0 281.0 281.0
The last positional argument of counterfactual (set to 200.0 here) is the learning rate, which usually needs to be tuned. The input for the observation we are interested in and five possible counterfactuals are given in the following table:
pretty_table(
    hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm), :]);
    alignment = [:l, :c, :c, :c, :c, :c, :c],
    backend = Val(:markdown),
    header = ["Variable", "Obs.", "C. 1", "C. 2", "C. 3", "C. 4", "C. 5"],
    # one decimal for the observation and all five counterfactuals (columns 2 to 7),
    # integer for the variable index (column 1)
    formatters = (ft_printf("%4.1f", [2, 3, 4, 5, 6, 7]), ft_printf("%d", 1)),
)
Variable | Obs. | C. 1 | C. 2 | C. 3 | C. 4 | C. 5 |
---|---|---|---|---|---|---|
1 | 9.8 | 11.6 | 13.1 | 11.0 | 13.2 | 12.7 |
12 | 947.0 | 768.1 | 942.5 | 725.1 | 948.0 | 907.9 |
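One common way to frame counterfactual search (Wachter et al., 2017) is to minimise a loss that adds a weighted squared distance between the model score and the target to the distance between the counterfactual and the original input. Assuming the learning rate passed to counterfactual plays a comparable weighting role (an assumption, not documented SDeMo behaviour), the sketch below shows why it needs tuning: a small weight favours staying close to the observation, a large one favours reaching the target. The score function, the helper names and all numbers are hypothetical.

```julia
# Hypothetical toy score standing in for a trained model:
# a logistic curve over the first two variables (the third is ignored).
toy_score(x) = 1 / (1 + exp(-(x[1] - 0.01 * x[2])))

# Wachter-style loss (assumed form): λ weights the squared distance to the
# target score, and the L1 norm penalises moving away from the input.
cf_loss(score, x, xcf, target, λ) = λ * (score(xcf) - target)^2 + sum(abs.(xcf .- x))

# Copy the full input and overwrite only the variables used by the model,
# so every other variable keeps its observed value.
function with_model_variables(x, vars, newvals)
    xcf = copy(x)
    xcf[vars] .= newvals
    return xcf
end

x = [2.0, 100.0, 7.0]  # hypothetical observation; variable 3 is unused
xcf = with_model_variables(x, [1, 2], [0.8, 100.0])  # score now close to 0.45
for λ in (1.0, 200.0)
    stay = cf_loss(toy_score, x, x, 0.45, λ)    # keep the original input
    move = cf_loss(toy_score, x, xcf, 0.45, λ)  # move toward the target
    println("λ = $λ: keep input → $(round(stay; digits = 3)), move → $(round(move; digits = 3))")
end
```

With λ = 1 keeping the input unchanged has the lower loss; with λ = 200 the candidate that reaches the target score wins, even though it moved further from the observation.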
We can check the predictions made on all five counterfactuals; the original prediction was true, and every counterfactual is now predicted as false:
predict(sdm, cf)
5-element BitVector:
0
0
0
0
0
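This check can be extended into a small summary of each counterfactual: whether it flips the original outcome, and how far each model variable had to move from the observation. The helper below is hypothetical and uses plain Julia only. With the objects from this vignette it could be called as summarise_counterfactuals(instance(sdm, inst; strict = false), cf, variables(sdm), outcome, predict(sdm, cf)), assuming these calls return a full input vector, a list of variable indices and a vector of booleans, as they appear to above.

```julia
# Hypothetical helper: `x` is the full original input, `cf` holds one
# counterfactual per column, `vars` are the variables used by the model,
# `outcome` is the original prediction and `preds` the predictions on `cf`.
function summarise_counterfactuals(x, cf, vars, outcome, preds)
    flipped = preds .!= outcome       # true where the prediction changed
    shifts = cf[vars, :] .- x[vars]   # change of each model variable, per column
    return flipped, shifts
end

# Small synthetic example with two counterfactuals over three variables
x = [1.0, 10.0, 5.0]
cfs = [1.5 2.0; 10.0 10.0; 4.0 6.0]
flipped, shifts = summarise_counterfactuals(x, cfs, [1, 3], true, [false, true])
```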