
AlexNet

eqxvision.models.AlexNet

A simple port of torchvision.models.alexnet

__init__(self, num_classes: int = 1000, dropout: float = 0.5, *, key: Optional[jax.random.PRNGKey] = None)

Arguments:

  • num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,). Defaults to 1000.
  • dropout: Dropout probability used by the equinox.nn.Dropout layers. Defaults to 0.5.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
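
The dropout value is the probability of zeroing an activation. As a hedged illustration, the sketch below implements inverted dropout in plain JAX. inverted_dropout is a hypothetical helper, and it is an assumption (not a quote of the library) that equinox.nn.Dropout behaves this way in training mode: each element is dropped with probability p and the surviving elements are scaled by 1 / (1 - p).

```python
import jax
import jax.numpy as jnp


def inverted_dropout(x, p, *, key):
    """Hypothetical helper: zero each element with probability p, rescale the rest.

    Assumed to match the training-mode behaviour of equinox.nn.Dropout;
    written for illustration, not copied from Equinox.
    """
    keep_prob = 1.0 - p
    # Bernoulli mask: True where the activation is kept.
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    # Rescale kept values so the expected activation is unchanged.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((4096,))
y = inverted_dropout(x, 0.5, key=jax.random.PRNGKey(0))
```

With p = 0.5, every surviving element of the all-ones input becomes 2.0, so the mean activation stays close to 1.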

__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array

Arguments:

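  • x: The input. Should be a JAX array with 3 channels, laid out channels-first as (3, height, width).
  • key: Required parameter. A jax.random.PRNGKey used to provide randomness for the Dropout layers.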
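
A minimal sketch of batched inference, assuming (as is the Equinox convention) that the model processes one unbatched, channels-first image per call. toy_model is a hypothetical stand-in with the same (x, *, key) signature, used so the snippet runs with JAX alone. The pattern is to split one PRNG key into one key per image and map over both the images and the keys with jax.vmap.

```python
import jax
import jax.numpy as jnp


def toy_model(x, *, key):
    """Hypothetical stand-in with AlexNet's call signature: (3, H, W) -> (3,)."""
    # Dropout-like randomness driven by the per-example key.
    mask = jax.random.bernoulli(key, 0.5, x.shape)
    x = jnp.where(mask, x / 0.5, 0.0)
    # Average over height and width, leaving one value per channel.
    return x.mean(axis=(1, 2))


images = jnp.ones((8, 3, 63, 63))  # a batch of 8 channels-first images
keys = jax.random.split(jax.random.PRNGKey(0), images.shape[0])  # one key per image
# Keyword arguments to a vmapped function are mapped over their leading axis.
logits = jax.vmap(toy_model)(images, key=keys)
print(logits.shape)  # (8, 3)
```

Under the same assumption, an AlexNet instance could take the place of toy_model; because key is mapped over its leading axis, each image receives its own dropout randomness.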

eqxvision.models.alexnet(torch_weights: str = None, **kwargs: Any) -> AlexNet

AlexNet model architecture from the "One weird trick..." paper. The minimum input size required by the model is 63x63.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
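
The 63x63 minimum can be checked by tracing the spatial size through the feature extractor with the usual formula floor((n + 2p - k) / s) + 1. The sketch below assumes the eqxvision port keeps torchvision's layer hyperparameters (collected in FEATURE_LAYERS, a hypothetical name) and searches for the smallest square input that still leaves a 1x1 feature map after the last max pool.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for the spatial layers of torchvision's AlexNet
# feature extractor; assumed to be mirrored by the eqxvision port.
FEATURE_LAYERS = [
    (11, 4, 2),  # conv1
    (3, 2, 0),   # max pool
    (5, 1, 2),   # conv2
    (3, 2, 0),   # max pool
    (3, 1, 1),   # conv3
    (3, 1, 1),   # conv4
    (3, 1, 1),   # conv5
    (3, 2, 0),   # max pool
]


def final_feature_size(n):
    """Side length of the last feature map for an n x n input (0 if too small)."""
    for kernel, stride, padding in FEATURE_LAYERS:
        n = out_size(n, kernel, stride, padding)
        if n < 1:
            return 0  # the input is too small for this layer
    return n


def min_input_size():
    """Smallest square input that yields at least a 1x1 feature map."""
    n = 1
    while final_feature_size(n) < 1:
        n += 1
    return n


print(final_feature_size(63), final_feature_size(62))  # 1 0
print(min_input_size())  # 63
```

The same helper gives a 6x6 feature map for the common 224x224 input.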