# Eqxvision
Eqxvision is a package of popular computer vision model architectures built using Equinox.
## Installation

Use the package manager pip to install `eqxvision`:

```bash
pip install eqxvision
```

Requires Python >= 3.7.
## Usage

### Example

Importing a model and running a forward pass is as simple as:
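```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet
from eqxvision.utils import CLASSIFICATION_URLS

@eqx.filter_jit
def forward(net, images, key):
    # Split off one PRNG key per image (used by stochastic layers such as dropout).
    keys = jax.random.split(key, images.shape[0])
    # Map the model over the batch axis; naming it 'batch' lets batch-norm layers
    # compute statistics across the batch.
    output = jax.vmap(net, axis_name='batch')(images, key=keys)
    ...
    return output

# Build AlexNet and load its pretrained torchvision weights.
net = alexnet(torch_weights=CLASSIFICATION_URLS['alexnet'])
images = jr.uniform(jr.PRNGKey(0), shape=(1, 3, 224, 224))
output = forward(net, images, jr.PRNGKey(0))
```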
## What's New?

- `FCN`, `DeepLabV3` and `LRASPP` segmentation models are now supported (check out the tutorial).
- `v0.2.0` introduces backward-incompatible changes to how `pretrained` models are loaded.
- Almost all image classification models are ported from `torchvision`.
- New tutorial for generating adversarial examples; others are coming soon.
## Get Started!

Start with any one of these easy-to-follow tutorials.
## Tips

- Prefer `@equinox.filter_jit` over `@jax.jit`: it traces the array leaves of a model and treats everything else (for example, activation functions) as static.
- Use `jax.{v,p}map` with `axis_name='batch'` when using models that contain batch normalisation.
- Don't forget to switch to `inference` mode for evaluation (`model = eqx.tree_inference(model, value=True)`).
- Initialise Optax optimisers as `optim.init(eqx.filter(net, eqx.is_array))`. (See here.)
## Contributing
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
### Development Process

If you plan to modify the code or documentation, please follow the steps below:

- Fork the repository and create your branch from `dev`.
- If you have modified the code (new feature or bug fix), please add unit tests.
- If you have changed APIs, update the documentation and make sure it builds:

    ```bash
    mkdocs serve
    ```

- Ensure the test suite passes:

    ```bash
    pytest tests -vvv
    ```

- Make sure your code passes the formatting checks; these are run automatically by a `pre-commit` hook.