
GoogLeNet

eqxvision.models.GoogLeNet

A simple port of torchvision.models.GoogLeNet to Equinox.

__init__(self, num_classes: int = 1000, aux_logits: bool = False, blocks: Optional[List[eqx.Module]] = None, dropout: float = 0.2, dropout_aux: float = 0.7, *, key: Optional[jax.random.PRNGKey] = None)

Arguments:

  • num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,). Defaults to 1000.
  • aux_logits: If True, two auxiliary branches are added to the network. Defaults to False.
  • blocks: Blocks used to construct the network. Defaults to None.
  • dropout: Dropout rate applied on the main branch. Defaults to 0.2.
  • dropout_aux: Dropout rate applied on the auxiliary branches. Defaults to 0.7.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Optional[Array]

Arguments:

  • x: The input. Should be a JAX array with 3 channels.
  • key: Required keyword argument. Utilised by the few layers that need randomness, such as Dropout or DropPath.
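Because key is required, running the model over a batch needs one key per sample. The sketch below assumes, following the usual Equinox convention, that the model consumes a single unbatched (C, H, W) image, and maps it over a batch with jax.vmap and jax.random.split. To keep it self-contained, toy_forward is a hypothetical stand-in (dropout followed by global average pooling), not part of eqxvision; a constructed GoogLeNet would be passed in its place.

```python
import jax
import jax.numpy as jnp


def toy_forward(x, *, key, rate=0.2):
    """Hypothetical stand-in for a model: dropout on one (C, H, W) image, then global average pooling."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    x = jnp.where(keep, x / (1.0 - rate), 0.0)  # inverted dropout
    return x.mean(axis=(1, 2))  # -> (C,)


def batched_forward(model, images, key):
    # One independent key per sample, so each image gets its own dropout mask
    keys = jax.random.split(key, images.shape[0])
    return jax.vmap(lambda img, k: model(img, key=k))(images, keys)


images = jnp.ones((4, 3, 32, 32))  # batch of 4 RGB images
out = batched_forward(toy_forward, images, jax.random.PRNGKey(0))
print(out.shape)  # (4, 3)
```

Splitting the key, rather than reusing a single key for every sample, keeps the per-sample randomness independent while remaining reproducible for a fixed input key.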

eqxvision.models.googlenet(torch_weights: str = None, **kwargs: Any) -> GoogLeNet

GoogLeNet (Inception v1) model architecture from Going Deeper with Convolutions. The required minimum input size of the model is 15x15.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
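The 15x15 minimum input size can be checked by tracing the spatial size through the downsampling layers. This sketch assumes the port keeps torchvision's layer settings: a 7x7 stride-2 stem convolution with padding 3, then max pools with kernels 3, 3, 3 and 2 (all stride 2, ceil_mode=True), while the inception blocks and the remaining convolutions preserve the spatial size. The helpers conv_out, pool_out and googlenet_feature_size are illustrative only, not eqxvision functions.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a max pool with ceil_mode=True, following PyTorch's rule."""
    out = -(-(size + 2 * padding - kernel) // stride) + 1  # ceil division
    # PyTorch drops a last window that would start beyond the input
    if (out - 1) * stride >= size + padding:
        out -= 1
    return max(out, 0)


def googlenet_feature_size(size: int) -> int:
    """Height (or width) of the feature map reaching the final average pool."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = pool_out(size, kernel=3, stride=2)  # maxpool1
    # conv2 (1x1) and conv3 (3x3, padding 1) keep the size
    size = pool_out(size, kernel=3, stride=2)  # maxpool2
    # inception3a/3b keep the size
    size = pool_out(size, kernel=3, stride=2)  # maxpool3
    # inception4a-4e keep the size
    size = pool_out(size, kernel=2, stride=2)  # maxpool4
    return size


smallest = min(s for s in range(1, 64) if googlenet_feature_size(s) >= 1)
print(smallest)  # 15
print(googlenet_feature_size(224))  # 7
```

Under these assumptions a 14x14 input collapses to an empty feature map at maxpool3, while 15x15 still leaves a 1x1 map for the average pool, consistent with the stated minimum; the usual 224x224 input yields a 7x7 map.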