AlexNet
eqxvision.models.AlexNet
A simple port of `torchvision.models.alexnet`.
__init__(self, num_classes: int = 1000, dropout: float = 0.5, *, key: Optional[jax.random.PRNGKey] = None)
Arguments:
num_classes: Number of classes in the classification task. Also controls the final output shape `(num_classes,)`. Defaults to `1000`.
dropout: Parameter used for the `equinox.nn.Dropout` layers. Defaults to `0.5`.
key: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
Arguments:
x: The input. Should be a JAX array with `3` channels.
key: Required parameter. (Keyword only argument.)
eqxvision.models.alexnet(torch_weights: str = None, **kwargs: Any) -> AlexNet
AlexNet model architecture from the "One weird trick for parallelizing convolutional neural networks" paper. The required minimum input size of the model is 63x63.
Arguments:
torch_weights: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`.