
Generating counterfactuals

The purpose of this vignette is to show how to generate counterfactual explanations from SDeMo models.

```julia
using SpeciesDistributionToolkit
using PrettyTables
```

We will work on the demo data:

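```julia
X, y, C = SDeMo.__demodata()            # demo dataset
sdm = SDM(RawData, NaiveBayes, X, y)    # untransformed inputs, Naive Bayes classifier
variables!(sdm, [1, 12])                # use only variables 1 and 12
train!(sdm)
```

```
☑️  RawData → NaiveBayes → P(x) ≥ 0.416
```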

We will focus on generating counterfactual inputs, i.e. hypothetical inputs that lead the model to predict the opposite outcome. Internally, candidate points are generated using the Nelder-Mead algorithm, a derivative-free method that works well enough in practice but is not compatible with categorical data, since it moves through a continuous space of inputs.
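The core of such a search can be sketched as a minimal, simplified Nelder-Mead minimizer in Base Julia. This is an illustration only: the function name `nelder_mead` and its keyword arguments are our own, not SDeMo's API. It only ever evaluates the objective `f`, and every new point is a continuous combination of existing simplex points (reflection, expansion, contraction, shrinking), which is why categorical variables cannot be handled this way.

```julia
# Simplified Nelder-Mead minimizer (an illustrative sketch, not SDeMo's code).
# f: objective, x0: starting point, step: size of the initial simplex.
function nelder_mead(f, x0::Vector{Float64}; step = 1.0, maxiter = 500, tol = 1e-10)
    n = length(x0)
    # Initial simplex: x0 plus one point displaced along each axis.
    simplex = [copy(x0) for _ in 1:(n + 1)]
    for i in 1:n
        simplex[i + 1][i] += step
    end
    fs = [f(x) for x in simplex]
    α, γ, ρ, σ = 1.0, 2.0, 0.5, 0.5  # reflection, expansion, contraction, shrink
    for _ in 1:maxiter
        # Sort points from best (lowest f) to worst.
        order = sortperm(fs)
        simplex, fs = simplex[order], fs[order]
        fs[end] - fs[1] < tol && break
        # Centroid of every point except the worst one.
        centroid = sum(simplex[1:n]) / n
        xr = centroid .+ α .* (centroid .- simplex[end])
        fr = f(xr)
        if fr < fs[1]
            # The reflected point is the best so far: try going further.
            xe = centroid .+ γ .* (xr .- centroid)
            fe = f(xe)
            simplex[end], fs[end] = fe < fr ? (xe, fe) : (xr, fr)
        elseif fr < fs[n]
            # The reflected point beats the second-worst one: accept it.
            simplex[end], fs[end] = xr, fr
        else
            # Contract the worst point toward the centroid.
            xc = centroid .+ ρ .* (simplex[end] .- centroid)
            fc = f(xc)
            if fc < fs[end]
                simplex[end], fs[end] = xc, fc
            else
                # Shrink every point toward the best one.
                for i in 2:(n + 1)
                    simplex[i] = simplex[1] .+ σ .* (simplex[i] .- simplex[1])
                    fs[i] = f(simplex[i])
                end
            end
        end
    end
    best = argmin(fs)
    return simplex[best], fs[best]
end

# Usage: minimize a shifted quadratic starting from the origin.
xbest, fbest = nelder_mead(x -> (x[1] - 3)^2 + (x[2] + 1)^2, [0.0, 0.0])
```

The coefficients (α = 1, γ = 2, ρ = 0.5, σ = 0.5) are the usual textbook defaults; a production implementation would add outside contraction and more careful stopping rules.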

We will pick one prediction to flip:

```julia
inst = 6
```

```
6
```

And look at its outcome:

```julia
outcome = predict(sdm)[inst]
```

```
false
```

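Our target is expressed in terms of the score we want the counterfactual to reach, not in terms of true/false; this is very important. Here the target is set 10% below the threshold if the current prediction is true, and 10% above it if the prediction is false: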

```julia
target = outcome ? 0.9threshold(sdm) : 1.1threshold(sdm)
```

```
0.45776404958196787
```
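The rule above can be wrapped in a small helper to make it explicit. The names `target_score` and `predicted` and the `margin` keyword are hypothetical; the convention that a score at or above the threshold is predicted as `true` follows the `P(x) ≥ 0.416` summary printed after training. Working with scores matters because the true/false prediction is constant on each side of the threshold, so a search driven only by it gets no signal about which direction improves; aiming 10% past the threshold also leaves a margin, so a candidate that stops near the target is clearly on the other side.

```julia
# Hypothetical helper reproducing the rule used above: aim `margin` (10%) past
# the threshold τ, on the side opposite to the current outcome.
target_score(outcome::Bool, τ::Real; margin = 0.1) =
    outcome ? (1 - margin) * τ : (1 + margin) * τ

# A score is predicted as `true` when it reaches the threshold.
predicted(score::Real, τ::Real) = score >= τ

τ = 0.416                      # threshold printed after training (rounded)
t = target_score(false, τ)     # 1.1τ, above the threshold
```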

The counterfactuals are then generated as follows. We repeat the search five times, and only the variables used by the model are modified:

```julia
cf = [
    counterfactual(
        sdm,
        instance(sdm, inst; strict = false),  # all 19 variables of the instance
        target,                               # target score, not a Bool
        200.0;                                # learning rate
        threshold = false,                    # work with scores, not true/false
    ) for _ in 1:5                            # five independent attempts
]
cf = hcat(cf...)                              # one column per counterfactual
```

```
19×5 Matrix{Float64}:
   13.3455    13.2192    13.2828    13.5033    13.4549
    5.1        5.1        5.1        5.1        5.1
    0.251      0.251      0.251      0.251      0.251
  537.3      537.3      537.3      537.3      537.3
   25.85      25.85      25.85      25.85      25.85
    5.65       5.65       5.65       5.65       5.65
   20.2       20.2       20.2       20.2       20.2
   13.05      13.05      13.05      13.05      13.05
   22.25      22.25      22.25      22.25      22.25
   22.25      22.25      22.25      22.25      22.25
    8.85       8.85       8.85       8.85       8.85
 1234.41    1545.24    1460.8     1327.67    1322.7
  209.3      209.3      209.3      209.3      209.3
   16.1       16.1       16.1       16.1       16.1
   51.1       51.1       51.1       51.1       51.1
  560.1      560.1      560.1      560.1      560.1
  101.7      101.7      101.7      101.7      101.7
  101.7      101.7      101.7      101.7      101.7
  418.2      418.2      418.2      418.2      418.2
```
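In this output, only rows 1 and 12 (the two variables used by the model) differ between runs; all other rows are identical across the five counterfactuals. One way to obtain this behaviour, sketched below under our own assumptions rather than from SDeMo's source, is to search only over the model's variables and write them back into a copy of the full input. The sketch also uses a Wachter-style loss, in which a weight λ balances reaching the target score against staying close to the original input. The names `with_variables`, `toy_score`, and `cf_loss` are hypothetical, and we do not claim that the `200.0` passed above plays exactly the role of λ.

```julia
# Hypothetical helper: copy the full input `x0` and overwrite only the
# positions in `vars` with the candidate values `z`.
function with_variables(x0::AbstractVector, vars::AbstractVector{<:Integer}, z::AbstractVector)
    x = copy(x0)
    x[vars] = z
    return x
end

# Hypothetical toy model: a logistic score that only uses variables 1 and 4.
toy_score(x) = 1 / (1 + exp(-(0.5 * x[1] - 0.01 * x[4])))

# Wachter-style loss (an assumption, not SDeMo's documented loss) over the
# model's variables only: λ weights the squared gap to the target score and
# the L1 term penalizes moving away from the original values.
function cf_loss(score, x0, vars, target, λ)
    z0 = x0[vars]
    return z -> λ * (score(with_variables(x0, vars, z)) - target)^2 + sum(abs, z .- z0)
end

x0 = [1.0, 7.0, 3.0, 100.0]  # full input; toy_score(x0) ≈ 0.378
vars = [1, 4]                # variables used by the toy model
z0 = x0[vars]
moved = [2.0, 100.0]         # one unit further on variable 1; score = 0.5

strong = cf_loss(toy_score, x0, vars, 0.6, 200.0)  # large λ: reaching the target dominates
weak = cf_loss(toy_score, x0, vars, 0.6, 0.01)     # small λ: staying close dominates
```

With λ = 200 the candidate that moves toward the target has the lower loss, while with λ = 0.01 the unchanged input wins; a weight of this kind is one reason such parameters need tuning.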

The fourth positional argument (200.0 here) is the learning rate, which usually needs to be tuned. The values of the observation we are interested in, as well as the five possible counterfactuals, are given in the following table for the two variables used by the model:

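```julia
pretty_table(
    hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm), :]);
    alignment = [:l, :c, :c, :c, :c, :c, :c],
    backend = :markdown,
    column_labels = ["Variable", "Obs.", "C. 1", "C. 2", "C. 3", "C. 4", "C. 5"],
    # one decimal for the observation and all five counterfactuals (columns 2 to 7)
    formatters = [fmt__printf("%4.1f", [2, 3, 4, 5, 6, 7]), fmt__printf("%d", [1])],
)
```

| Variable | Obs. | C. 1 | C. 2 | C. 3 | C. 4 | C. 5 |
|:---------|:----:|:----:|:----:|:----:|:----:|:----:|
| 1 | 15.2 | 13.3 | 13.2 | 13.3 | 13.5 | 13.5 |
| 12 | 1318.0 | 1234.4 | 1545.2 | 1460.8 | 1327.7 | 1322.7 |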
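The two variables have very different magnitudes (about 15 versus about 1300), so raw differences between the observation and each counterfactual are not comparable. A minimal sketch, assuming that a sum of absolute relative changes is an acceptable way to rank the candidates (SDeMo's own distance is not shown in this vignette), uses the values printed above; the observation is only known to one decimal from the table.

```julia
# Values copied from the output above: the observation (rounded in the table)
# and the five counterfactuals, for variables 1 and 12.
obs = [15.2, 1318.0]
cfs = [13.3455 13.2192 13.2828 13.5033 13.4549;
       1234.41 1545.24 1460.8  1327.67 1322.7]

# Relative change of each variable, so that variables with very different
# scales contribute comparably.
rel = (cfs .- obs) ./ obs

# Sum of absolute relative changes per counterfactual (one value per column).
dist = vec(sum(abs.(rel); dims = 1))
closest = argmin(dist)
```

With these numbers, every counterfactual lowers variable 1, and the fifth one changes the inputs least in relative terms, narrowly ahead of the fourth; a different distance could rank them differently.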

We can check the prediction that the model makes for each counterfactual:

```julia
predict(sdm, cf)
```

```
5-element BitVector:
 1
 1
 1
 1
 1
```
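All five counterfactuals are predicted as `true`, whereas the original instance was predicted as `false`, so each of them flips the outcome. A small helper (the name `all_flipped` is hypothetical) makes this check explicit:

```julia
# Hypothetical helper: true when every counterfactual prediction differs from
# the original outcome, i.e. every candidate actually flipped the prediction.
all_flipped(original::Bool, preds::AbstractVector{Bool}) = all(p -> p != original, preds)
```

With the objects defined in this vignette, the same check can be written as `all(predict(sdm, cf) .!= outcome)`.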