Counterfactuals

Comparison of counterfactual explanations

The following example shows a few ways of computing counterfactual explanations.

import numpy as np
import matplotlib.pylab as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

from wildboar.ensemble import ShapeletForestClassifier
from wildboar.datasets import load_dataset
from wildboar.explain.counterfactual import counterfactuals

# Seed shared by the train/test split and the seeded estimators, and passed to
# ``counterfactuals`` so the resulting figure is reproducible.
random_state = 1234
x, y = load_dataset("GunPoint")

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=random_state
)

# Each entry pairs a panel title with the estimator whose counterfactual
# explanation is drawn in that panel.
classifiers = [
    (
        "Shapelet Forest Counterfactuals",
        ShapeletForestClassifier(
            metric="euclidean", random_state=random_state, n_estimators=100
        ),
    ),
    ("KNearest Counterfactuals", KNeighborsClassifier(metric="euclidean")),
    ("Prototype Counterfactuals", RandomForestClassifier(random_state=random_state)),
]

# Desired (target) class of the counterfactuals: the first label returned by
# np.unique, i.e. the smallest label present in the training set.
label = np.unique(y_train)[0]

# Only test samples that are not already labelled with the target class are
# explained. The selection is independent of the classifier, so compute it once.
x_test_sample = x_test[y_test != label]

# One panel per classifier, sharing the time axis.
fig, ax = plt.subplots(nrows=len(classifiers), sharex=True)
for i, (name, clf) in enumerate(classifiers):
    clf.fit(x_train, y_train)
    if isinstance(clf, RandomForestClassifier):
        # Only the RandomForestClassifier entry is given background data; the
        # training set is passed as ``background_x`` / ``background_y``.
        kwargs = {"background_x": x_train, "background_y": y_train}
    else:
        kwargs = {}

    x_counterfactual, valid = counterfactuals(
        clf, x_test_sample, label, random_state=random_state, **kwargs
    )
    # ``valid[0]`` flags whether the first counterfactual is valid; mark the
    # panel title when it is not.
    ax[i].set_title(name + (" (invalid)" if not valid[0] else ""))
    # Overlay the first counterfactual (red) on the very sample it was
    # computed from (blue), so the two series are directly comparable.
    ax[i].plot(x_counterfactual[0], c="red")
    ax[i].plot(x_test_sample[0], c="blue")

ax[-1].legend(["x'", "x"])
plt.savefig("../fig/counterfactuals.png")
![Comparison of counterfactual explanations](../_images/counterfactuals.png)