
MobileNet-V3

eqxvision.models.MobileNetV3

A simple port of torchvision.models.mobilenetv3.

__init__(self, inverted_residual_setting: List[_InvertedResidualConfig], last_channel: int, num_classes: int = 1000, block: Optional[eqx.Module] = None, norm_layer: Optional[eqx.Module] = None, dropout: float = 0.2, *, key: Optional[jax.random.PRNGKey] = None)

Arguments:

  • inverted_residual_setting: The network structure, given as a list of _InvertedResidualConfig entries, one per inverted residual block.
  • last_channel: The number of channels on the penultimate layer.
  • num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,). Defaults to 1000.
  • block: Module specifying the inverted residual building block for MobileNet.
  • norm_layer: Module specifying the normalization layer to use.
  • dropout: The dropout probability. Defaults to 0.2.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
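
The dropout argument is the probability of zeroing an activation. As a minimal sketch, assuming the port follows the usual "inverted dropout" convention (as equinox.nn.Dropout does), the snippet below shows what a rate of 0.2 means: each element is dropped with probability 0.2 and the survivors are scaled by 1/(1 - 0.2) so the expected activation is unchanged. The helper inverted_dropout is hypothetical and written only for illustration; it is not part of eqxvision.

```python
import jax
import jax.numpy as jnp


def inverted_dropout(x, p, key):
    """Hypothetical helper: zero each element with probability p and
    rescale the survivors by 1 / (1 - p). Requires 0 <= p < 1."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep_prob = 1.0 - p
    # True where an element survives (happens with probability keep_prob).
    keep = jax.random.bernoulli(key, keep_prob, x.shape)
    # Rescale survivors so the expected value of each element is preserved.
    return jnp.where(keep, x / keep_prob, 0.0)


x = jnp.ones((10_000,))
y = inverted_dropout(x, 0.2, jax.random.PRNGKey(0))
print(float(y.mean()))  # close to 1.0, since dropped entries are offset by the rescaling
```

With p = 0.0 the function returns its input unchanged, which is also the behaviour one would expect from a model with dropout disabled.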

__call__(self, x, *, key: jax.random.PRNGKey) -> Array

Arguments:

  • x: The input JAX array.
  • key: Required parameter. Used by layers that need randomness, such as Dropout or DropPath.
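
Because key is required, batching typically means supplying one key per example. The sketch below assumes, as is usual for Equinox modules, that __call__ acts on a single unbatched image of shape (C, H, W) and that a batch is handled with jax.vmap. The function toy_call is a hypothetical stand-in for the model (global average pooling followed by dropout), not eqxvision's implementation; it only demonstrates the calling pattern.

```python
import jax
import jax.numpy as jnp


def toy_call(x, *, key, p=0.2):
    # Hypothetical stand-in for MobileNetV3.__call__ on ONE image (C, H, W):
    # global-average-pool to (C,), then apply dropout driven by `key`.
    feats = jnp.mean(x, axis=(1, 2))
    keep = jax.random.bernoulli(key, 1.0 - p, feats.shape)
    return jnp.where(keep, feats / (1.0 - p), 0.0)


batch = jnp.ones((4, 3, 32, 32))  # (B, C, H, W)
# One independent key per example, so each example gets its own dropout mask.
keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])
# jax.vmap maps over the leading axis of the positional input and of keyword arguments.
out = jax.vmap(toy_call)(batch, key=keys)
print(out.shape)  # (4, 3)
```

Under the stated assumption, an actual MobileNetV3 instance would take the place of toy_call in the same jax.vmap(model)(batch, key=keys) pattern.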

eqxvision.models.mobilenet_v3_small(torch_weights: str = None, **kwargs: Any) -> MobileNetV3

Constructs the small MobileNetV3 architecture from the paper "Searching for MobileNetV3".

Arguments:

  • torch_weights: A path or URL for the PyTorch weights. Defaults to None.

eqxvision.models.mobilenet_v3_large(torch_weights: str = None, **kwargs: Any) -> MobileNetV3

Constructs the large MobileNetV3 architecture from the paper "Searching for MobileNetV3".

Arguments:

  • torch_weights: A path or URL for the PyTorch weights. Defaults to None.