Adversarial Attack (FGSM)
In this tutorial we generate an adversarial example using the Fast Gradient Sign Method (FGSM).
It is based on the Torchvision FGSM tutorial; see that tutorial for a more in-depth discussion of adversarial examples.
The tutorial proceeds as follows:
- Prepare the input image
- Initialise a model
- Compute the FGSM perturbation
- Visualise the results
Installing Dependencies
!pip install eqxvision optax jaxlib torchvision --quiet
Required Imports
from io import BytesIO
from urllib.request import urlopen
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import requests
from PIL import Image
import eqxvision as eqv
from eqxvision.utils import CLASSIFICATION_URLS
%matplotlib inline
import optax
from torchvision import transforms
Preparing Image & Transforms
# Download the sample image. The timeout keeps the notebook from hanging on a
# stalled connection, and raise_for_status() reports HTTP errors directly
# instead of letting them surface later as an opaque PIL decoding error.
response = requests.get(
    "https://cdn.britannica.com/80/150980-050-84B9202C/Giant-panda-cub-branch.jpg",
    timeout=30,
)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
img = img.convert("RGB")

# Transform: resize to 224x224, convert to a channels-first tensor, then
# normalise per channel. These mean/std values are the usual ImageNet
# statistics (presumably matching the pretrained model used below).
mean = (0.485, 0.456, 0.406)
std_dev = (0.229, 0.224, 0.225)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std_dev),
    ]
)
# Add a leading batch axis and hand the data to JAX: shape (1, 3, 224, 224).
img = jnp.asarray(transform(img).unsqueeze(0).numpy())
def inv_transform(x, means=None, std_devs=None):
    """Undo per-channel normalisation and convert a CHW image to HWC for plotting.

    Args:
        x: Normalised image of shape ``(C, H, W)`` (channels first).
        means: Per-channel means used for normalisation. Defaults to the
            module-level ``mean``.
        std_devs: Per-channel standard deviations used for normalisation.
            Defaults to the module-level ``std_dev``.

    Returns:
        Array of shape ``(H, W, C)`` with values clipped to ``[0, 1]``, so
        ``imshow`` renders it as RGB without range-clipping warnings (the
        adversarial image can step outside the valid pixel range).
    """
    means = jnp.asarray(mean if means is None else means).reshape(-1, 1, 1)
    std_devs = jnp.asarray(std_dev if std_devs is None else std_devs).reshape(-1, 1, 1)
    # Inverse of Normalize: x_norm = (x - m) / s  =>  x = s * x_norm + m.
    x = jnp.clip(x * std_devs + means, 0.0, 1.0)
    return jnp.transpose(x, (1, 2, 0))
Prediction to Class Names
# Prediction to class name mapping: line i of the downloaded file describes
# class index i. Only the text before the first comma is kept (presumably the
# primary name when a line lists several).
cls_map_link = (
    "https://github.com/Waikato/wekaDeeplearning4j/blob/master/src/"
    "main/resources/class-maps/IMAGENET.txt?raw=True"
)
# Use the response as a context manager so the connection is closed, and bound
# the wait so a stalled download doesn't hang the notebook.
with urlopen(cls_map_link, timeout=30) as resp:
    # splitlines() avoids a spurious empty "class" for the trailing newline.
    data = resp.read().decode("utf-8").splitlines()
cls_map = {i: line.strip().split(",")[0] for i, line in enumerate(data)}
Initialising Model
# Build ResNet-50 from the weights registered under "resnet50" and flip every
# layer's inference flag on (affects layers such as dropout/batch-norm).
weights_url = CLASSIFICATION_URLS["resnet50"]
model = eqx.tree_inference(eqv.models.resnet50(weights_url), value=True)
# A stack containing a single PRNG key: one key per image in the batch, for the
# vmapped forward passes below.
key = jrandom.split(jrandom.PRNGKey(0), num=1)
@eqx.filter_value_and_grad
def compute_loss(x, model, y, keys):
    """Mean softmax cross-entropy of ``model`` on batch ``x`` against label(s) ``y``.

    Wrapped by ``eqx.filter_value_and_grad``, so a call returns
    ``(loss, grads)`` where ``grads`` is the gradient with respect to the first
    argument ``x`` (the input batch). That gradient is what FGSM needs.

    Args:
        x: Batch of inputs; ``model`` is vmapped over its leading axis.
        model: Callable invoked as ``model(x_i, key=k_i)`` per batch element.
        y: Integer class label(s), one-hot encoded against the logits width.
        keys: PRNG keys, one per batch element, vmapped alongside ``x``.
    """
    logits = jax.vmap(model, axis_name="batch")(x, key=keys)
    # Derive the class count from the model output instead of hard-coding 1000,
    # so the loss works for any classifier head.
    one_hot_actual = jax.nn.one_hot(y, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(logits, one_hot_actual).mean()
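FGSM Computation
For an image \(x\) with label \(y\), model parameters \(\theta\) and loss \(J\), the adversarial example is:
\(x_{adv} = x + \epsilon \cdot \operatorname{sign}\left(\nabla_x J(\theta, x, y)\right)\)
Here \(y\) is the model's own prediction on the clean image, and \(\epsilon = 0.1\).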
# Clean prediction. The model's own top-1 class serves as the label for the
# loss, since no ground-truth label is supplied here.
output = jax.vmap(model, axis_name="batch")(img, key=key)
pred_cls = jnp.argmax(output, axis=1).item()
print(f"Original Category: {cls_map[pred_cls]}")

# Gradient of the loss with respect to the input image.
loss, grads = compute_loss(img, model, pred_cls, key)

# FGSM step: nudge every value by epsilon in the direction that increases the
# loss. epsilon is measured in normalised (post-Normalize) units.
epsilon = 0.1
del_x = epsilon * jnp.sign(grads)

# Keep the adversarial example a realisable image: clamp each channel to the
# normalised range that maps back to [0, 1] pixel values, as the Torchvision
# FGSM tutorial does with its clamp to [0, 1].
mean_arr = jnp.asarray(mean).reshape(1, -1, 1, 1)
std_arr = jnp.asarray(std_dev).reshape(1, -1, 1, 1)
adv_img = jnp.clip(img + del_x, -mean_arr / std_arr, (1.0 - mean_arr) / std_arr)

output = jax.vmap(model, axis_name="batch")(adv_img, key=key)
pred_adv = jnp.argmax(output, axis=1).item()
print(f"Adversarial Category: {cls_map[pred_adv]}")
Visualisation
# Show the clean image, the FGSM perturbation and the adversarial image side by
# side, each titled with its predicted class.
fig, axs = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)

# Clip to [0, 1] so imshow treats the arrays as valid RGB floats.
clean_rgb = jnp.clip(inv_transform(img[0]), 0.0, 1.0)
adv_rgb = jnp.clip(inv_transform(adv_img[0]), 0.0, 1.0)

# del_x is a scaled sign map (negative and positive values). Min-max rescale it
# to [0, 1]; otherwise imshow clips the negative half to black.
pert = jnp.transpose(del_x[0], (1, 2, 0))
pert = (pert - pert.min()) / jnp.maximum(pert.max() - pert.min(), 1e-12)

panels = [
    (clean_rgb, f"{cls_map[pred_cls]}"),
    (pert, "grad. image"),
    (adv_rgb, f"{cls_map[pred_adv]}"),
]
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()
That's all, folks!