Skip to content

Transfer Learning

Here we will try to kill two birds with one stone: prepare a tutorial on transfer learning, and verify the quality of the pretrained VGG models.

For those of you who are not aware, the pretrained VGGs in eqxvision perform poorly compared to their torchvision counterparts. This is mainly due to differences in the implementation of Equinox's adaptive average pooling.

This tutorial is organised as follows:

  • Preparing train/val datasets
  • Preparing model
  • Setting up forward and loss computation methods
  • Initialising the optimizer
  • Model Training
  • Verifying the integrity of the weights

Installing Dependencies

!pip install eqxvision optax jaxlib --quiet
     |████████████████████████████████| 145 kB 13.4 MB/s 
     |████████████████████████████████| 66 kB 5.1 MB/s 
     |████████████████████████████████| 76 kB 6.0 MB/s 

Basic Imports

import functools as ft

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import eqxvision as eqv
from eqxvision.utils import CLASSIFICATION_URLS

Hyper-parameters

# Core training knobs, read by the data loaders, the LR schedule and the loop.
EPOCHS = 5  # full passes over the training split
BATCH_SIZE = 128  # images per optimisation step
LR = 1e-3  # peak learning rate handed to the cosine schedule

Dataset & Dataloaders

# Both pipelines end with the same per-channel normalisation; these are the
# standard ImageNet mean/std used with torchvision's pretrained classifiers.
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training images are randomly cropped/resized to 224x224 and flipped, for
# augmentation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)
# Evaluation images are only resized (no randomness). Resize(224) scales the
# shorter side to 224; NOTE(review): this yields 224x224 only for square
# inputs — add a CenterCrop(224) if non-square images are ever used.
val_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]
)
train_dataset = datasets.STL10(
    root="/tmp", split="train", transform=train_transform, download=True
)
val_dataset = datasets.STL10(
    root="/tmp", split="test", transform=val_transform, download=True
)

# DataLoader does not shuffle by default; reshuffle the training set every
# epoch so successive optimisation steps do not see batches in a fixed order.
train_loader = DataLoader(
    dataset=train_dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = DataLoader(dataset=val_dataset, num_workers=2, batch_size=BATCH_SIZE)
Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /tmp/stl10_binary.tar.gz

  0%|          | 0/2640397119 [00:00<?, ?it/s]
Extracting /tmp/stl10_binary.tar.gz to /tmp
Files already downloaded and verified

Model Preparation

We need to perform two steps after initialising the model.

  1. Replace the final classification layer to suit the STL-10 dataset.
  2. Freeze the parameters for all layers except the classification layer.
# Load VGG-11 initialised from the pretrained torchvision weights.
model = eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"])

# Step 1: swap the whole classifier for a single 10-way linear layer
# (one output per STL-10 class); its input size 512 * 7 * 7 matches the
# flattened feature map fed to the classifier.
model = eqx.tree_at(
    lambda net: net.classifier,
    model,
    replace=eqx.nn.Linear(512 * 7 * 7, 10, key=jrandom.PRNGKey(0)),
)

# Step 2: build a boolean mask shaped like the model that is False everywhere
# except the new head's weight and bias; compute_loss differentiates only
# the leaves marked True, which freezes everything else.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
    lambda spec: (spec.classifier.weight, spec.classifier.bias),
    filter_spec,
    replace=(True, True),
)
  0%|          | 0.00/507M [00:00<?, ?B/s]

Utility Methods

The filter_spec decides which parameters the gradient is computed with respect to. Here, we compute the gradient with respect to the classifier module only.

See the Equinox documentation for more details.

@ft.partial(eqx.filter_value_and_grad, arg=filter_spec)
def compute_loss(model, x, y, keys):
    """Mean softmax cross-entropy of ``model`` on one batch.

    The decorator turns this into a function returning ``(loss, grads)``,
    where gradients are taken only with respect to the leaves of ``model``
    that ``filter_spec`` marks as True (the classifier head).

    Args:
        model: the network, called per example as ``model(x_i, key=k_i)``.
        x: batch of inputs, mapped over its leading axis.
        y: integer class labels, one per example.
        keys: one PRNG key per example, mapped alongside ``x``.

    Returns:
        The scalar mean cross-entropy over the batch (plus, via the
        decorator, its gradients).
    """
    # Name the vmapped axis "batch"; presumably for batch-aware layers that
    # refer to it by name.
    logits = jax.vmap(model, axis_name="batch")(x, key=keys)
    # Take the class count from the head's output instead of hard-coding it,
    # so swapping in a differently sized classifier needs no change here.
    one_hot_actual = jax.nn.one_hot(y, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(logits, one_hot_actual).mean()


@eqx.filter_jit
def make_step(model, x, y, keys, optimizer, opt_state):
    """Run one optimisation step on a single batch.

    Returns a ``(loss, model, opt_state)`` triple: the batch loss computed
    before the update, the updated model and the updated optimiser state.
    ``filter_jit`` traces the array arguments and treats the rest (such as
    the optax ``optimizer``) as static.
    """
    batch_loss, grads = compute_loss(model, x, y, keys)
    updates, next_opt_state = optimizer.update(grads, opt_state)
    next_model = eqx.apply_updates(model, updates)
    return batch_loss, next_model, next_opt_state


def accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: callable invoked per example as ``model(image, key=key)``;
            the last axis of its batched output holds the class scores.
        loader: iterable of ``(images, labels)`` pairs of torch tensors
            (anything exposing ``.numpy()`` and ``.shape`` works).

    Returns:
        A scalar JAX array holding the fraction of correctly classified
        examples (callers use ``.item()`` on it).

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    # Every batch splits its per-example keys from the same fixed seed, so
    # evaluation is deterministic; build that seed once, outside the loop.
    base_key = jrandom.PRNGKey(0)
    correct = 0.0
    total = 0.0
    for images, labels in loader:
        batch_size = images.shape[0]
        keys = jrandom.split(base_key, batch_size)
        output = jax.vmap(model, axis_name="batch")(
            jnp.asarray(images.numpy()), key=keys
        )
        pred = jnp.argmax(output, axis=-1)

        correct += jnp.sum(pred == labels.numpy())
        total += batch_size
    if total == 0:
        # Without this guard the division below fails with a bare
        # ZeroDivisionError that does not explain the cause.
        raise ValueError("accuracy() received a loader that yielded no examples")
    return correct / total

Optimizer & Scheduler

The important thing to remember is to wrap the model in eqx.filter before passing it to the optimizer; this step will fail if you forget the filter.

# One optimiser update per batch. len(train_loader) already counts the final
# partial batch (drop_last is False), so this is the exact number of updates
# even when the dataset size is a multiple of BATCH_SIZE.
total_steps = EPOCHS * len(train_loader)
# alpha=0.95 keeps the schedule's multiplier >= 0.95, so the learning rate only
# decays from LR down to 0.95 * LR over the run.
cosine_decay_scheduler = optax.cosine_decay_schedule(
    LR, decay_steps=total_steps, alpha=0.95
)
optimizer = optax.adam(learning_rate=cosine_decay_scheduler)
# Wrap in a filter to avoid passing non-JAX types: only the model's array
# leaves are handed to optax when building Adam's state.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

Training

# Exponential moving average of the training loss, used only for logging.
# It starts at 0 and is carried across epochs (never reset), so its earliest
# values are biased low.
loss = 0

for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_loader):
        # Seed from the epoch and step indices, then split into one key per
        # example for the vmapped forward pass.
        # NOTE(review): epoch + batch_size * step can repeat for some
        # (epoch, step) pairs (e.g. once epoch reaches the batch size, or via
        # the smaller final batch); jrandom.fold_in would guarantee distinct keys.
        key = jrandom.PRNGKey(epoch + x.shape[0] * step)
        keys = jrandom.split(key, x.shape[0])
        loss_value, model, opt_state = make_step(
            model, jnp.asarray(x), jnp.asarray(y), keys, optimizer, opt_state
        )
        # Blend in the newest batch loss with weight 0.1.
        loss = 0.9 * loss + 0.1 * loss_value.item()

    # Switch to inference behaviour for evaluation, then back for training.
    model = eqx.tree_inference(model, True)  # Analogous to model.eval()
    # NOTE(review): train_loader applies the random training augmentations
    # (train_transform), so tr.acc is measured on augmented crops and can sit
    # below te.acc.
    train_acc = accuracy(model, train_loader)
    test_acc = accuracy(model, val_loader)
    model = eqx.tree_inference(model, False)  # Back to training mode

    print(
        f"Epoch={epoch}, loss={loss:.4f}, tr.acc={train_acc.item():.4f}, te.acc={test_acc.item():.4f}"
    )
Epoch=0, loss=0.6885, tr.acc=0.8134, te.acc=0.9073
Epoch=1, loss=0.5187, tr.acc=0.8174, te.acc=0.9018
Epoch=2, loss=0.4801, tr.acc=0.8450, te.acc=0.9125
Epoch=3, loss=0.4588, tr.acc=0.8560, te.acc=0.9176
Epoch=4, loss=0.4834, tr.acc=0.8476, te.acc=0.9059

Verify Weights

The last step is to verify that the weights of model.features are unchanged and that only model.classifier has been updated.

# Reload an untouched copy of the pretrained network to compare against.
base_model = eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"])
base_model = eqx.tree_inference(base_model, True)
# Only the classifier head was trainable, so the `features` part of the
# fine-tuned model must still equal the freshly loaded weights.
assert eqx.tree_equal(base_model.features, model.features)

That's all, folks!