AlexNet
eqxvision.models.AlexNet
A simple port of torchvision.models.alexnet.
__init__(self, num_classes: int = 1000, dropout: float = 0.5, *, key: Optional[jax.random.PRNGKey] = None)
Arguments:
num_classes
: Number of classes in the classification task. Also controls the final output shape (num_classes,). Defaults to 1000.
dropout
: Dropout rate passed to the equinox.nn.Dropout layers. Defaults to 0.5.
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
Arguments:
x
: The input. Should be a JAX array with 3 channels.
key
: A jax.random.PRNGKey. Required keyword-only argument.
eqxvision.models.alexnet(torch_weights: str = None, **kwargs: Any) -> AlexNet
AlexNet model architecture from the One weird trick... paper. The required minimum input size of the model is 63x63.
Arguments:
torch_weights
: A Path or URL for the PyTorch weights. Defaults to None.