Generating counterfactuals
The purpose of this vignette is to show how to generate counterfactual explanations from SDeMo models.
using SpeciesDistributionToolkit
using PrettyTables
We will work on the demo data:
X, y, C = SDeMo.__demodata()
sdm = SDM(RawData, NaiveBayes, X, y)
variables!(sdm, [1, 12])
train!(sdm)
☑️ RawData → NaiveBayes → P(x) ≥ 0.416
We will focus on generating a counterfactual input, i.e. a set of hypothetical inputs that lead the model to predict the other outcome. Internally, candidate points are generated using the Nelder-Mead algorithm, which works well enough but is not compatible with categorical data.
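To make the idea of a counterfactual search concrete, here is a minimal sketch of the kind of objective such a search could minimize. It is not SDeMo's internal code: the toy `score` function, its coefficients, the helper `cf_loss`, and the exact form of the objective (squared gap to the target score weighted by `λ`, plus squared distance to the original input) are all assumptions made for illustration. The weight `λ` plays the role of what this vignette calls the learning rate.

```julia
# Toy stand-in for a trained model: a logistic score on two variables.
# The coefficients are made up for illustration.
score(x) = 1 / (1 + exp(-(0.8 * x[1] - 0.5 * x[2])))

# Assumed objective (hypothetical, not SDeMo's): squared gap between the
# candidate's score and the target, weighted by λ, plus the squared distance
# between the candidate and the original input.
cf_loss(candidate, x, target, λ) =
    λ * (score(candidate) - target)^2 + sum(abs2, candidate .- x)

x = [0.0, 1.0]            # original input
target = 0.6              # score we want the counterfactual to reach

close_miss = [0.2, 0.9]   # stays near x but does not reach the target score
on_target = [1.0, 0.79]   # reaches the target score but moves further from x

for λ in (1.0, 200.0)
    println("λ = ", λ,
        "  close_miss: ", cf_loss(close_miss, x, target, λ),
        "  on_target: ", cf_loss(on_target, x, target, λ))
end
```

Under this assumed objective, a small `λ` favors candidates that stay close to the original input, whereas a large `λ` favors candidates that actually reach the target score. This is one way to see why the value usually needs tuning.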
We will pick one prediction to flip:
inst = 66
And look at its outcome:
outcome = predict(sdm)[inst]
false
Our target must be expressed as the score we want the counterfactual to reach, not as a true/false outcome. Because the outcome is false here, the target is set 10% above the threshold; a positive prediction would instead get a target 10% below it:
target = outcome ? 0.9 * threshold(sdm) : 1.1 * threshold(sdm)
0.45776404958196787
The counterfactuals are generated as follows (only the variables used by the model are changed):
cf = [
    counterfactual(
        sdm,
        instance(sdm, inst; strict = false),
        target,
        200.0;
        threshold = false,
    ) for _ in 1:5
]
cf = hcat(cf...)
19×5 Matrix{Float64}:
13.473 13.5045 13.513 13.6084 13.5036
5.1 5.1 5.1 5.1 5.1
0.251 0.251 0.251 0.251 0.251
537.3 537.3 537.3 537.3 537.3
25.85 25.85 25.85 25.85 25.85
5.65 5.65 5.65 5.65 5.65
20.2 20.2 20.2 20.2 20.2
13.05 13.05 13.05 13.05 13.05
22.25 22.25 22.25 22.25 22.25
22.25 22.25 22.25 22.25 22.25
8.85 8.85 8.85 8.85 8.85
1301.85 1312.55 1318.8 1361.57 1298.73
209.3 209.3 209.3 209.3 209.3
16.1 16.1 16.1 16.1 16.1
51.1 51.1 51.1 51.1 51.1
560.1 560.1 560.1 560.1 560.1
101.7 101.7 101.7 101.7 101.7
101.7 101.7 101.7 101.7 101.7
418.2 418.2 418.2 418.2 418.2
The last positional argument (set to 200.0 here) is the learning rate, which usually needs to be tuned. The input for the observation we are interested in, together with the five counterfactuals, is given in the following table:
pretty_table(
    hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm), :]);
    alignment = [:l, :c, :c, :c, :c, :c, :c],
    backend = :markdown,
    column_labels = ["Variable", "Obs.", "C. 1", "C. 2", "C. 3", "C. 4", "C. 5"],
    formatters = [fmt__printf("%4.1f", [2, 3, 4, 5, 6, 7]), fmt__printf("%d", [1])],
)
| Variable | Obs. | C. 1 | C. 2 | C. 3 | C. 4 | C. 5 |
|---|---|---|---|---|---|---|
| 1 | 15.2 | 13.5 | 13.5 | 13.5 | 13.6 | 13.5 |
| 12 | 1318.0 | 1301.9 | 1312.6 | 1318.8 | 1361.6 | 1298.7 |
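To read the table more easily, it can help to express each counterfactual as a change relative to the observation. The sketch below does this with plain arrays; the numbers are copied by hand from the output shown in this vignette (so they are rounded), and the names `obs`, `cfs`, `Δ` and `relative` are only illustrative.

```julia
# Values copied from the vignette output (rows: variables 1 and 12).
obs = [15.2, 1318.0]
cfs = [
    13.473   13.5045  13.513  13.6084  13.5036
    1301.85  1312.55  1318.8  1361.57  1298.73
]

# Absolute change of each counterfactual with respect to the observation;
# broadcasting subtracts obs from every column.
Δ = cfs .- obs

# Change relative to the observed value of each variable.
relative = Δ ./ obs

display(round.(Δ; digits = 2))
display(round.(relative; digits = 3))
```

In this particular run, every counterfactual lowers variable 1, while variable 12 moves in either direction depending on the counterfactual.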
We can check the prediction that would be made on all the counterfactuals:
predict(sdm, cf)
5-element BitVector:
1
1
1
1
1
Related documentation
SDeMo.counterfactual Function
counterfactual(model::AbstractSDM, x::Vector{T}, yhat, λ; maxiter=100, minvar=5e-5, kwargs...) where {T <: Number}
Generates one counterfactual explanation given an input vector x and a target yhat to reach. The learning rate is λ. The maximum number of iterations used in the Nelder-Mead algorithm is maxiter, and minvar is the variance improvement below which the search stops. Other keywords are passed to predict.
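To illustrate how `maxiter` and `minvar` can act as stopping rules, here is a simplified, textbook-style Nelder-Mead loop. It is a sketch and not SDeMo's implementation: the function name `neldermead_sketch`, the `step` keyword, and the reading of `minvar` as a bound on the variance of the objective values across the simplex are all assumptions.

```julia
using Statistics: var

# Simplified Nelder-Mead (reflection, expansion, inside contraction, shrink).
# Hypothetical helper: the stopping rule on `minvar` is an assumed reading
# of the docstring, not confirmed SDeMo behavior.
function neldermead_sketch(f, x0::Vector{Float64}; maxiter = 100, minvar = 5e-5, step = 1.0)
    n = length(x0)
    # Initial simplex: x0 plus one point shifted along each axis
    simplex = [copy(x0) for _ in 1:(n + 1)]
    for i in 1:n
        simplex[i + 1][i] += step
    end
    fvals = [f(p) for p in simplex]
    iter = 0
    # Stop on the iteration budget, or when the objective values at the
    # simplex vertices have become nearly identical
    while iter < maxiter && var(fvals) > minvar
        iter += 1
        order = sortperm(fvals)
        simplex, fvals = simplex[order], fvals[order]
        centroid = sum(simplex[1:n]) / n     # centroid of all points but the worst
        worst = simplex[end]
        reflected = centroid .+ (centroid .- worst)
        fr = f(reflected)
        if fr < fvals[1]
            # Best point so far: try to go further in the same direction
            expanded = centroid .+ 2.0 .* (centroid .- worst)
            fe = f(expanded)
            if fe < fr
                simplex[end], fvals[end] = expanded, fe
            else
                simplex[end], fvals[end] = reflected, fr
            end
        elseif fr < fvals[n]
            # Better than the second-worst point: accept the reflection
            simplex[end], fvals[end] = reflected, fr
        else
            # Contract the worst point toward the centroid
            contracted = centroid .+ 0.5 .* (worst .- centroid)
            fc = f(contracted)
            if fc < fvals[end]
                simplex[end], fvals[end] = contracted, fc
            else
                # Shrink every point toward the best one
                for i in 2:(n + 1)
                    simplex[i] = simplex[1] .+ 0.5 .* (simplex[i] .- simplex[1])
                    fvals[i] = f(simplex[i])
                end
            end
        end
    end
    best = argmin(fvals)
    return simplex[best], fvals[best], iter
end

# Usage on a simple two-variable bowl with its minimum at (1, -2)
bowl(x) = (x[1] - 1.0)^2 + (x[2] + 2.0)^2
xbest, fbest, iters = neldermead_sketch(bowl, [0.0, 0.0]; maxiter = 500, minvar = 1e-12)
println("best point: ", xbest, "  value: ", fbest, "  iterations: ", iters)
```

In this sketch, a larger `minvar` stops the search earlier and returns a coarser solution, while `maxiter` caps the cost when the simplex is slow to settle.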