LRASPP
eqxvision.models.LRASPP
Implements a Lite R-ASPP Network for semantic segmentation from "Searching for MobileNetV3".
__init__(self, backbone: eqx.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128, key: Optional[jax.random.PRNGKey] = None)
Arguments:
- backbone: the network used to compute the features for the model. The intermediate layers of the backbone should be wrapped so that intermediate features can be obtained.
- low_channels: the number of channels of the low-level features.
- high_channels: the number of channels of the high-level features.
- num_classes: the number of output classes of the model (including the background).
- inter_channels: the number of channels for intermediate computations.
- key: a jax.random.PRNGKey used to provide randomness for parameter initialisation.
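The way the head combines the low- and high-level features can be sketched in plain JAX. This sketch assumes the port mirrors torchvision's LRASPPHead: a 1x1 conv + ReLU branch on the high-level features is gated by a global-average-pool + 1x1 conv + sigmoid branch, bilinearly upsampled to the low-level resolution, and each path then gets its own 1x1 classifier before the two are summed. BatchNorm is omitted for brevity; init_head, lraspp_head and the parameter dictionary layout are hypothetical names, and the channel counts 40 and 960 in the usage lines are illustrative values for a MobileNetV3-Large backbone.

```python
import jax
import jax.numpy as jnp


def conv1x1(w, b, x):
    """1x1 convolution on a single (C, H, W) image: w is (C_out, C_in), b is (C_out,)."""
    return jnp.einsum("oc,chw->ohw", w, x) + b[:, None, None]


def init_head(key, low_channels, high_channels, num_classes, inter_channels=128):
    """Hypothetical parameter layout for an LR-ASPP head (no BatchNorm, for brevity)."""
    k1, k2, k3, k4 = jax.random.split(key, 4)

    def dense(k, out_c, in_c):
        w = jax.random.normal(k, (out_c, in_c)) / jnp.sqrt(in_c)
        return w, jnp.zeros((out_c,))

    return {
        "cbr": dense(k1, inter_channels, high_channels),    # high-level branch
        "scale": dense(k2, inter_channels, high_channels),  # gating branch
        "low_cls": dense(k3, num_classes, low_channels),    # classifier on low-level features
        "high_cls": dense(k4, num_classes, inter_channels), # classifier on gated features
    }


def lraspp_head(params, low, high):
    """low: (low_channels, H1, W1); high: (high_channels, H2, W2) at a coarser resolution."""
    x = jax.nn.relu(conv1x1(*params["cbr"], high))
    # Gate: global average pool -> 1x1 conv -> sigmoid, giving one weight per channel.
    pooled = jnp.mean(high, axis=(1, 2), keepdims=True)
    s = jax.nn.sigmoid(conv1x1(*params["scale"], pooled))  # (inter_channels, 1, 1)
    x = x * s
    # Upsample the gated high-level features to the low-level spatial size.
    x = jax.image.resize(x, (x.shape[0],) + low.shape[1:], method="bilinear")
    return conv1x1(*params["low_cls"], low) + conv1x1(*params["high_cls"], x)


params = init_head(jax.random.PRNGKey(0), low_channels=40, high_channels=960, num_classes=21)
low = jnp.ones((40, 28, 28))
high = jnp.ones((960, 14, 14))
out = lraspp_head(params, low, high)
print(out.shape)  # (21, 28, 28)
```

The gate is a per-channel weight computed from global context, which is what keeps this head "lite" compared with the multi-branch atrous pooling of full ASPP.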
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Tuple[Any, Array]
Arguments:
- x: The input. Should be a JAX array with 3 channels.
- key: Required parameter. Utilised by a few layers such as Dropout or DropPath.
Returns: A tuple with outputs from the intermediate and last layers.
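Because the call above takes a single 3-channel image and a key, a batch can be processed by mapping over images with one key each. The sketch below assumes that pattern; forward is a hypothetical stand-in for net(x, key=key) that returns (None, random logits) so the snippet runs without model weights, and the None intermediate output is an assumption for illustration. The last output is reduced to per-pixel class ids with argmax.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 21  # matches the factory's default num_classes


def forward(x, *, key):
    """Hypothetical stand-in for net(x, key=key) on ONE image of shape (3, H, W).

    Returns (intermediate, last); here intermediate is None and last is a
    (NUM_CLASSES, H, W) map of random logits.
    """
    h, w = x.shape[1:]
    logits = jax.random.normal(key, (NUM_CLASSES, h, w))
    return None, logits


batch = jnp.ones((4, 3, 64, 64))                                # (N, 3, H, W)
keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])  # one key per image
_, logits = jax.vmap(lambda img, k: forward(img, key=k))(batch, keys)
labels = jnp.argmax(logits, axis=1)                             # per-pixel class ids, (N, H, W)
print(logits.shape, labels.shape)  # (4, 21, 64, 64) (4, 64, 64)
```

With a real model, replace forward with the network instance; splitting the key per image keeps stochastic layers such as Dropout independent across the batch.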
eqxvision.models.lraspp_mobilenet_v3_large(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> LRASPP
Implements a Lite R-ASPP Network model with a MobileNetV3-Large backbone from the "Searching for MobileNetV3" paper.
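Sample call:
from eqxvision.models import lraspp_mobilenet_v3_large, mobilenet_v3_large
from eqxvision.utils import SEGMENTATION_URLS  # assumed location of SEGMENTATION_URLS

net = lraspp_mobilenet_v3_large(
    backbone=mobilenet_v3_large(dilated=True),
    intermediate_layers=lambda x: [4, 16],
    torch_weights=SEGMENTATION_URLS["lraspp_mobilenetv3_large"],
)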
Arguments:
- num_classes: Number of classes in the segmentation task. Also controls the final output shape (num_classes, height, width). Defaults to 21.
- backbone: The neural network to use for extracting features. If None, the default LRASPP_MobileNetV3 configuration with untrained weights is used.
- intermediate_layers: Layers from backbone to be used for generating output maps. Assuming the backbone is a MobileNetV3, the default sets it to indices [4, 16] in backbone.features.
- torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation.
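Selecting layers by index, as intermediate_layers does with [4, 16], amounts to running the backbone's layers in order and keeping the outputs at those positions: the earlier index yields the higher-resolution low-level features and the later one the coarser high-level features. The sketch below assumes a backbone whose features are a plain sequence of callables; collect_features, make_layer and the toy pooling layers are hypothetical, and calling intermediate_layers with the layer list is an assumption (the lambda ignores its argument anyway).

```python
import jax.numpy as jnp


def collect_features(layers, x, indices):
    """Run `layers` in order on x and return outputs of the layers at `indices`.

    Outputs come back in layer order; `layers` is a hypothetical stand-in for
    backbone.features.
    """
    wanted = set(indices)
    last = max(indices)
    outputs = []
    for i, layer in enumerate(layers):
        x = layer(x)
        if i in wanted:
            outputs.append(x)
        if i >= last:
            break  # layers after the last requested index are not needed
    return outputs


def make_layer(i):
    """Toy layer: every 4th layer (after the first) halves H and W with 2x2 average pooling."""
    def layer(x):
        if i % 4 == 0 and i > 0:
            c, h, w = x.shape
            x = x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))
        return x
    return layer


layers = [make_layer(i) for i in range(17)]
intermediate_layers = lambda x: [4, 16]  # mirrors the sample call
x = jnp.ones((3, 64, 64))
low, high = collect_features(layers, x, intermediate_layers(layers))
print(low.shape, high.shape)  # (3, 32, 32) (3, 4, 4)
```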