LRASPP

eqxvision.models.LRASPP

Implements a Lite R-ASPP Network for semantic segmentation from "Searching for MobileNetV3".

__init__(self, backbone: eqx.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128, key: Optional[jax.random.PRNGKey] = None)

Arguments:

  • backbone: the network used to compute the features for the model. The intermediate layers of the backbone should be wrapped so that their intermediate features can be obtained

  • low_channels: the number of channels of the low level features

  • high_channels: the number of channels of the high level features

  • num_classes: number of output classes of the model (including the background)

  • inter_channels: the number of channels for intermediate computations

  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation
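
To show how low_channels, high_channels, inter_channels and num_classes relate to one another, here is a minimal sketch of an LR-ASPP style head in plain JAX. It assumes the layout used in the torchvision reference implementation (a projected main branch gated by a globally pooled attention branch, then two 1x1 classifiers) and leaves out batch normalisation. The helpers conv1x1, init_head and head_forward are hypothetical names for illustration, not part of eqxvision, and the channel counts are illustrative.

```python
import jax
import jax.numpy as jnp


def conv1x1(weight, x):
    # A 1x1 convolution on a (channels, height, width) array is a matrix
    # product over the channel axis.
    return jnp.einsum("oc,chw->ohw", weight, x)


def init_head(key, low_channels, high_channels, num_classes, inter_channels=128):
    # Hypothetical parameter container: one weight matrix per 1x1 convolution.
    shapes = {
        "cbr": (inter_channels, high_channels),
        "scale": (inter_channels, high_channels),
        "low_classifier": (num_classes, low_channels),
        "high_classifier": (num_classes, inter_channels),
    }
    keys = jax.random.split(key, len(shapes))
    return {
        name: 0.01 * jax.random.normal(k, shape)
        for (name, shape), k in zip(shapes.items(), keys)
    }


def head_forward(params, low, high):
    # Main branch: project the high-level features, then apply ReLU
    # (batch normalisation omitted in this sketch).
    x = jax.nn.relu(conv1x1(params["cbr"], high))
    # Attention branch: global average pool, project, squash to (0, 1).
    pooled = jnp.mean(high, axis=(1, 2), keepdims=True)
    x = x * jax.nn.sigmoid(conv1x1(params["scale"], pooled))
    # Upsample to the spatial size of the low-level features.
    x = jax.image.resize(x, (x.shape[0], *low.shape[1:]), method="bilinear")
    # Sum the per-class scores from both resolutions.
    return conv1x1(params["low_classifier"], low) + conv1x1(params["high_classifier"], x)


# Illustrative channel counts and feature-map sizes.
params = init_head(jax.random.PRNGKey(0), low_channels=40, high_channels=960, num_classes=21)
low = jnp.ones((40, 28, 28))
high = jnp.ones((960, 14, 14))
logits = head_forward(params, low, high)
```

In this sketch the output takes its channel count from num_classes and its spatial size from the low-level feature map, while inter_channels only sets the width of the intermediate projection.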

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Tuple[Any, Array]

Arguments:

  • x: The input. Should be a JAX array with 3 channels
  • key: Described as a required parameter, although the signature defaults it to None. Utilised by a few layers such as Dropout or DropPath

Returns: A tuple with outputs from the intermediate and last layers.
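
The documented output shape (num_classes, height, width) has no batch axis, which suggests the model processes one image per call. Under that assumption, a batch can be handled with jax.vmap, giving each image its own key. The sketch below uses a hypothetical stand-in, toy_model, that only mimics the calling convention (an array plus a keyword-only key, returning a tuple); it is not the eqxvision model.

```python
import jax
import jax.numpy as jnp


def toy_model(x, *, key=None):
    # Hypothetical stand-in with the same calling convention as __call__:
    # takes one (channels, height, width) image and returns a tuple of
    # (intermediate output, final output). The key is accepted but unused.
    intermediate = x.mean(axis=0, keepdims=True)
    final = jnp.concatenate([x, intermediate], axis=0)
    return intermediate, final


images = jnp.zeros((8, 3, 64, 64))  # a batch of 8 three-channel images
keys = jax.random.split(jax.random.PRNGKey(0), images.shape[0])

# Map over the leading batch axis of both the images and the keys.
batched = jax.vmap(lambda x, k: toy_model(x, key=k))
intermediate, final = batched(images, keys)
```

Both elements of the returned tuple gain a leading batch axis, so the tuple structure of the single-image call is preserved.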


eqxvision.models.lraspp_mobilenet_v3_large(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> LRASPP

Implements a Lite R-ASPP Network model with a MobileNetV3-Large backbone, from the "Searching for MobileNetV3" paper.

Sample call

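# Import paths are assumed from the eqxvision package layout.
from eqxvision.models import lraspp_mobilenet_v3_large, mobilenet_v3_large
from eqxvision.utils import SEGMENTATION_URLS

net = lraspp_mobilenet_v3_large(
    backbone=mobilenet_v3_large(dilated=True),
    intermediate_layers=lambda x: [4, 16],
    torch_weights=SEGMENTATION_URLS['lraspp_mobilenetv3_large']
)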

Arguments:

  • num_classes: Number of classes in the segmentation task. Also controls the final output shape (num_classes, height, width). Defaults to 21
  • backbone: The neural network to use for extracting features. If None, all parameters default to those of LRASPP_MobileNetV3, with untrained weights
  • intermediate_layers: Layers from the backbone to be used for generating output maps. Assuming the backbone is a MobileNetV3, the default selects indices [4, 16] in backbone.features
  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation
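
Since num_classes sets the leading axis of the (num_classes, height, width) output, a per-pixel label map can be recovered by reducing over that axis. The sketch below assumes the output holds unnormalised per-class scores (logits), which this page does not state, and uses a synthetic array in place of a real model output.

```python
import jax
import jax.numpy as jnp

num_classes, height, width = 21, 4, 4

# Synthetic scores standing in for a model output: class 7 wins everywhere.
logits = jnp.zeros((num_classes, height, width)).at[7, :, :].set(1.0)

# Per-pixel class probabilities (assumes the scores are logits).
probs = jax.nn.softmax(logits, axis=0)

# Per-pixel predicted class index, shape (height, width).
labels = jnp.argmax(logits, axis=0)
```

The reduction runs over axis 0 because the class axis comes first and there is no batch axis; for batched outputs the class axis would shift by one.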