LRASPP
eqxvision.models.LRASPP
Implements a Lite R-ASPP Network for semantic segmentation from "Searching for MobileNetV3".
__init__(self, backbone: eqx.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128, key: Optional[jax.random.PRNGKey] = None)
Arguments:

- `backbone`: the network used to compute the features for the model. The intermediate layers of the `backbone` should be wrapped for obtaining intermediate features.
- `low_channels`: the number of channels of the low-level features.
- `high_channels`: the number of channels of the high-level features.
- `num_classes`: the number of output classes of the model (including the background).
- `inter_channels`: the number of channels for intermediate computations.
- `key`: a `jax.random.PRNGKey` used to provide randomness for parameter initialisation.
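The constructor arguments above map onto the segmentation head roughly as follows. This NumPy sketch assumes the head follows torchvision's `LRASPPHead` design: a 1x1 conv + ReLU branch on the high-level features, gated by a global-average-pooled sigmoid branch, upsampled to the low-level resolution and summed with a 1x1 classifier over the low-level features. It is not eqxvision's code. BatchNorm is omitted, nearest-neighbour upsampling stands in for bilinear interpolation, `conv1x1`, `upsample_nearest` and `lraspp_head` are hypothetical helpers, and the channel sizes 40 and 960 are assumed MobileNetV3-Large values used only for illustration.

```python
import numpy as np


def conv1x1(weight, x, bias=None):
    """1x1 convolution on an unbatched (C, H, W) array; weight has shape (C_out, C_in)."""
    y = np.einsum("oc,chw->ohw", weight, x)
    if bias is not None:
        y = y + bias[:, None, None]
    return y


def upsample_nearest(x, size):
    """Nearest-neighbour upsampling by an integer factor (stand-in for bilinear)."""
    fh, fw = size[0] // x.shape[1], size[1] // x.shape[2]
    return np.repeat(np.repeat(x, fh, axis=1), fw, axis=2)


def lraspp_head(low, high, params):
    """Hypothetical sketch of an LR-ASPP head on unbatched feature maps.

    low:  (low_channels, H_low, W_low), higher-resolution low-level features
    high: (high_channels, H_high, W_high), lower-resolution high-level features
    """
    # Branch 1: 1x1 conv (high_channels -> inter_channels) + ReLU; BatchNorm omitted.
    x = np.maximum(conv1x1(params["cbr"], high), 0.0)
    # Branch 2: global average pool -> 1x1 conv -> sigmoid gives per-channel gates.
    pooled = high.mean(axis=(1, 2), keepdims=True)
    gate = 1.0 / (1.0 + np.exp(-conv1x1(params["scale"], pooled)))
    x = x * gate  # (inter_channels, 1, 1) broadcasts over H, W
    # Bring the gated features up to the low-level resolution.
    x = upsample_nearest(x, low.shape[1:])
    # Per-pixel class scores from both resolutions, summed.
    return conv1x1(params["low_w"], low, params["low_b"]) + conv1x1(
        params["high_w"], x, params["high_b"]
    )


rng = np.random.default_rng(0)
low_channels, high_channels, inter_channels, num_classes = 40, 960, 128, 21
params = {
    "cbr": rng.normal(size=(inter_channels, high_channels)) * 0.01,
    "scale": rng.normal(size=(inter_channels, high_channels)) * 0.01,
    "low_w": rng.normal(size=(num_classes, low_channels)) * 0.01,
    "low_b": np.zeros(num_classes),
    "high_w": rng.normal(size=(num_classes, inter_channels)) * 0.01,
    "high_b": np.zeros(num_classes),
}
low = rng.normal(size=(low_channels, 64, 64))
high = rng.normal(size=(high_channels, 32, 32))
out = lraspp_head(low, high, params)
print(out.shape)  # (21, 64, 64)
```

The point of the sketch is the data flow: `inter_channels` only lives inside the head, while `num_classes` fixes the channel count of the output map.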
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Tuple[Any, Array]
Arguments:

- `x`: The input. Should be a JAX array with `3` channels.
- `key`: Required parameter. Utilised by a few layers such as `Dropout` or `DropPath`.

Returns: A tuple with outputs from the intermediate and last layers.
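A common first step is turning an image into the 3-channel input described above. The sketch below assumes, as is the Equinox convention, that the model takes a single unbatched, channels-first `(3, H, W)` image (batches would go through `jax.vmap`), and that inputs are normalised with the ImageNet mean and standard deviation conventionally used with torchvision's pretrained weights. Both are assumptions, and `to_model_input` is a hypothetical helper; the resulting NumPy array can be converted with `jnp.asarray` before calling the model.

```python
import numpy as np

# ImageNet statistics used with torchvision's pretrained weights (assumed to apply here).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_model_input(image_hwc):
    """Convert an (H, W, 3) uint8 RGB image into a normalised (3, H, W) float32 array."""
    if image_hwc.ndim != 3 or image_hwc.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {image_hwc.shape}")
    x = image_hwc.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the trailing channel axis
    return np.transpose(x, (2, 0, 1))  # channels first, no batch dimension


image = np.full((4, 6, 3), 255, dtype=np.uint8)  # a tiny all-white test image
x = to_model_input(image)
print(x.shape)  # (3, 4, 6)
```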
eqxvision.models.lraspp_mobilenet_v3_large(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> LRASPP
Implements a Lite R-ASPP Network model with a MobileNetV3-Large backbone from the "Searching for MobileNetV3" paper.
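Sample call:

```python
from eqxvision.models import lraspp_mobilenet_v3_large, mobilenet_v3_large
from eqxvision.utils import SEGMENTATION_URLS

net = lraspp_mobilenet_v3_large(
    backbone=mobilenet_v3_large(dilated=True),
    intermediate_layers=lambda x: [4, 16],
    torch_weights=SEGMENTATION_URLS["lraspp_mobilenetv3_large"],
)
```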
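The sample call passes `intermediate_layers=lambda x: [4, 16]`, i.e. a callable that returns the indices of the backbone layers whose outputs should be kept. One way to picture what wrapping the backbone achieves is the hypothetical helper `run_with_intermediates` below: it applies a sequence of layers in order and records the outputs at the requested indices. This is an illustration of the idea, not eqxvision's actual mechanism, and the 20 toy layers merely stand in for `backbone.features`. In the LR-ASPP setting, the two kept maps would presumably play the roles of the low- and high-level features.

```python
from typing import Any, Callable, List, Sequence, Tuple


def run_with_intermediates(
    layers: Sequence[Callable[[Any], Any]], x: Any, indices: Sequence[int]
) -> Tuple[List[Any], Any]:
    """Apply `layers` in order and keep the outputs of the layers at `indices`.

    Hypothetical helper for illustration; not part of eqxvision.
    """
    wanted = set(indices)
    kept = []
    for i, layer in enumerate(layers):
        x = layer(x)
        if i in wanted:
            kept.append(x)
    return kept, x


# Same shape of callable as in the sample call; its argument is ignored here.
intermediate_layers = lambda _: [4, 16]

# Toy stand-in for `backbone.features`: 20 layers that each add 1.
toy_features = [lambda v: v + 1 for _ in range(20)]

kept, final = run_with_intermediates(toy_features, 0, intermediate_layers(None))
print(kept, final)  # [5, 17] 20
```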
Arguments:

- `num_classes`: Number of classes in the segmentation task. Also controls the final output shape `(num_classes, height, width)`. Defaults to `21`.
- `backbone`: The neural network to use for extracting features. If `None`, all params are set to the `LRASPP_MobileNetV3` defaults with `untrained` weights.
- `intermediate_layers`: Layers from `backbone` to be used for generating output maps. Assuming the backbone is a `MobileNetV3`, the default sets it to indices `[4, 16]` in `backbone.features`.
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation.
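Since `num_classes` fixes the final output shape to `(num_classes, height, width)`, a per-pixel segmentation mask can be read off by taking the highest-scoring class along the first axis. The sketch below uses NumPy and a hand-built array of scores in place of a real model output; `logits_to_mask` is a hypothetical helper.

```python
import numpy as np


def logits_to_mask(logits):
    """Turn (num_classes, H, W) class scores into an (H, W) map of predicted class indices."""
    return np.argmax(logits, axis=0)


num_classes, height, width = 21, 2, 3
logits = np.zeros((num_classes, height, width), dtype=np.float32)
logits[15, 0, 0] = 5.0  # make class 15 win at pixel (0, 0)
mask = logits_to_mask(logits)
print(mask.shape)  # (2, 3)
print(mask[0, 0], mask[1, 2])  # 15 0 (ties resolve to the lowest index, class 0)
```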