Transfer Learning#
Here we will try to kill two birds with one stone:
- Prepare a tutorial on transfer learning.
- Verify the goodness of the pretrained VGG models.
For those of you not aware, pretrained VGGs in eqxvision perform poorly in comparison to their torchvision counterparts. The main reason is a difference in the implementation of Equinox's adaptive average pooling.
The flow of this tutorial is as follows:
- Preparing train/val datasets
- Preparing model
- Setting up forward and loss computation methods
- Initialising the optimizer
- Model Training
- Verifying the integrity of weights
Installing Dependencies#
!pip install eqxvision optax jaxlib --quiet
Basic Imports#
import functools as ft
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import eqxvision as eqv
from eqxvision.utils import CLASSIFICATION_URLS
Hyper-parameters#
BATCH_SIZE = 128
LR = 0.001
EPOCHS = 5
Dataset & Dataloaders#
train_transform = transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
)
val_transform = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
)
train_dataset = datasets.STL10(
root="/tmp", split="train", transform=train_transform, download=True
)
val_dataset = datasets.STL10(
root="/tmp", split="test", transform=val_transform, download=True
)
train_loader = DataLoader(
    dataset=train_dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True
)  # reshuffle the training set each epoch
val_loader = DataLoader(dataset=val_dataset, num_workers=2, batch_size=BATCH_SIZE)
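As an optional sanity check, we can pull a single batch and confirm the shapes the model will see; with the transforms above, each image should come out as 3×224×224.
# Optional sanity check: inspect one batch of STL-10 after the training transforms
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # expected: torch.Size([128, 3, 224, 224]) torch.Size([128])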
Model Prep.#
We need to perform two steps after initialising the model:
- Replace the final classification layer to suit the STL-10 dataset.
- Freeze the parameters of all layers except the classification layer.
model = eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"])
# Replacing the last layer for STL-10
model = eqx.tree_at(
lambda m: m.classifier,
model,
(eqx.nn.Linear(512 * 7 * 7, 10, key=jrandom.PRNGKey(0))),
)
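As a quick sketch to confirm the surgery worked, a single dummy input should now map to 10 logits; the key argument mirrors how the model is called under jax.vmap later on.
# Optional check: the new head maps a 3x224x224 input to 10 logits
dummy = jnp.zeros((3, 224, 224))
out = model(dummy, key=jrandom.PRNGKey(0))
assert out.shape == (10,)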
# Freezing the model except for the last layer
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
lambda tree: (tree.classifier.weight, tree.classifier.bias),
filter_spec,
replace=(True, True),
)
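To see what this filter specification unlocks, here is a small optional inspection: filtering the model with filter_spec should leave exactly the classifier's weight and bias as array leaves.
# Inspection: only the classifier's parameters are marked trainable
trainable = jax.tree_util.tree_leaves(eqx.filter(model, filter_spec))
print([p.shape for p in trainable])  # expected: [(10, 25088), (10,)]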
@ft.partial(eqx.filter_value_and_grad, arg=filter_spec)
def compute_loss(model, x, y, keys):
    logits = jax.vmap(model, axis_name="batch")(x, key=keys)
one_hot_actual = jax.nn.one_hot(y, num_classes=10)
return optax.softmax_cross_entropy(logits, one_hot_actual).mean()
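As an aside, recent optax versions also ship optax.softmax_cross_entropy_with_integer_labels, which skips the explicit one-hot step; a minimal sketch of the equivalence:
# Equivalence check: integer-label cross-entropy should match the one-hot version
logits = jnp.ones((4, 10))
labels = jnp.array([0, 1, 2, 3])
a = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10))
b = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
assert jnp.allclose(a, b)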
@eqx.filter_jit
def make_step(model, x, y, keys, optimizer, opt_state):
loss, grads = compute_loss(model, x, y, keys)
updates, opt_state = optimizer.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state
def accuracy(model, loader):
    correct = 0.0
    total = 0.0
    for images, labels in loader:
        # One key per sample; dropout is disabled in inference mode, but the API still expects keys
        keys = jrandom.split(jrandom.PRNGKey(0), images.shape[0])
        output = jax.vmap(model, axis_name="batch")(
            jnp.asarray(images.numpy()), key=keys
        )
        pred = jnp.argmax(output, axis=1)
        correct += jnp.sum(pred == labels.numpy())
        total += images.shape[0]
    return correct / total
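Optionally, the helper can be run before training to establish a baseline; a freshly initialised 10-way head should sit near chance level (~10%).
# Optional baseline: evaluate the untrained head in inference mode
print(accuracy(eqx.tree_inference(model, True), val_loader))  # roughly 0.10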
Optimizer & Scheduler#
The important bit to remember is wrapping the model in eqx.filter before passing it on to the optimizer. optimizer.init will fail on the model's non-JAX leaves (such as activation functions) if you forget the filter.
total_steps = EPOCHS * (len(train_loader.dataset) // BATCH_SIZE) + EPOCHS
cosine_decay_scheduler = optax.cosine_decay_schedule(
LR, decay_steps=total_steps, alpha=0.95
)
optimizer = optax.adam(learning_rate=cosine_decay_scheduler)
opt_state = optimizer.init(
    eqx.filter(model, eqx.is_array)
)  # Wrap in a filter to avoid passing non-JAX types
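Optax schedules are plain callables from step count to learning rate, which makes the decay easy to spot-check; with alpha=0.95 the rate should only decay to 95% of LR by the final step.
# Spot-check the schedule endpoints
print(cosine_decay_scheduler(0))            # 0.001 (initial LR)
print(cosine_decay_scheduler(total_steps))  # 0.00095 (LR * alpha)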
The Training#
loss = 0.0  # exponential moving average of the training loss
for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_loader):
        # A fresh key per step, split into one key per sample (consumed by dropout)
        key = jrandom.PRNGKey(epoch + x.shape[0] * step)
        keys = jrandom.split(key, x.shape[0])
        loss_value, model, opt_state = make_step(
            model, jnp.asarray(x), jnp.asarray(y), keys, optimizer, opt_state
        )
        loss = 0.9 * loss + 0.1 * loss_value.item()
model = eqx.tree_inference(model, True) # Analogous to model.eval()
train_acc = accuracy(model, train_loader)
test_acc = accuracy(model, val_loader)
model = eqx.tree_inference(model, False) # Back to training mode
print(
f"Epoch={epoch}, loss={loss:.4f}, tr.acc={train_acc.item():.4f}, te.acc={test_acc.item():.4f}"
)
Verify Weights#
The last bit is to verify that the weights of model.features are unchanged and that only model.classifier has been updated.
base_model = eqx.tree_inference(
eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"]), True
)
assert eqx.tree_equal(base_model.features, model.features)
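For completeness, the flip side can also be checked; since the pretrained Sequential head was swapped for a freshly trained Linear layer, the two classifiers should no longer compare equal.
# Conversely, the classifier head should no longer match the pretrained one
assert not eqx.tree_equal(base_model.classifier, model.classifier)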
That's all, folks!