DeepLabV3
eqxvision.models.DeepLabV3
Ported from torchvision.models.segmentation.deeplabv3
__init__(self, backbone: eqx.Module, classifier: eqx.Module, aux_classifier: Optional[eqx.Module] = None)
Arguments:
backbone: The network used to compute the features for the model. The backbone returns embedding_features (ignored), [output features of intermediate layers].
classifier: Module that takes the last of the intermediate outputs from the backbone and returns a dense prediction.
aux_classifier: If used, an auxiliary classifier similar to classifier for the auxiliary layer.
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Tuple[Union[Any, Array], Array]
Arguments:
x: The input. Should be a JAX array with 3 channels.
key: Required parameter. Used by a few layers, such as Dropout or DropPath.
Returns: A tuple with outputs from the intermediate and last layers.
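The return layout described above can be sketched as follows. This is a minimal illustration rather than the library's own code: stand_in_model is a hypothetical placeholder that only mimics the call signature and the (auxiliary output, final output) tuple, and the example assumes, as is conventional for Equinox modules, that the model acts on a single channels-first image, so batches go through jax.vmap with one key per image.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 21  # default num_classes of the deeplabv3 constructor


def stand_in_model(x, *, key):
    """Hypothetical placeholder mimicking the DeepLabV3 call signature.

    Returns random logits shaped (num_classes, height, width) for both the
    auxiliary and the final output; a real model would compute these.
    """
    _, height, width = x.shape
    aux_key, out_key = jax.random.split(key)
    aux = jax.random.normal(aux_key, (NUM_CLASSES, height, width))
    out = jax.random.normal(out_key, (NUM_CLASSES, height, width))
    return aux, out


key = jax.random.PRNGKey(0)

# Single image: 3 channels, channels first.
image = jnp.zeros((3, 64, 64))
aux_logits, logits = stand_in_model(image, key=key)

# Per-pixel class prediction: argmax over the class axis.
class_map = jnp.argmax(logits, axis=0)

# Batch of images: map over the leading axis, one PRNG key per image.
images = jnp.zeros((4, 3, 64, 64))
keys = jax.random.split(key, images.shape[0])
batched_model = jax.vmap(lambda img, k: stand_in_model(img, key=k))
batch_aux_logits, batch_logits = batched_model(images, keys)
```

With a real model in place of the placeholder, the first tuple element may be absent or unused when no auxiliary classifier is configured, so code consuming it should not assume it is always an array.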
eqxvision.models.deeplabv3(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, classifier_module: eqx.Module = None, classifier_in_channels: int = 2048, aux_classifier_module: eqx.Module = None, aux_in_channels: int = 1024, silence_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> DeepLabV3
Implements the DeepLabV3 model from the paper Rethinking Atrous Convolution for Semantic Image Segmentation.
Sample call
net = deeplabv3(
    backbone=resnet50(replace_stride_with_dilation=[False, True, True]),
    intermediate_layers=lambda x: [x.layer3, x.layer4],
    aux_in_channels=1024,
    torch_weights=SEGMENTATION_URLS["deeplabv3_resnet50"],
)
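The effect of replace_stride_with_dilation=[False, True, True] in the sample call can be worked out with a little arithmetic. The sketch below assumes the backbone follows the torchvision ResNet convention: the stem downsamples by 4, and each of layer2, layer3 and layer4 downsamples by a further 2 unless its stride is replaced with dilation. The helper resnet_output_stride is hypothetical and not part of eqxvision.

```python
def resnet_output_stride(replace_stride_with_dilation):
    """Total downsampling factor of a torchvision-style ResNet (assumed layout)."""
    # Stem: stride-2 convolution followed by stride-2 max pooling; layer1 keeps resolution.
    stride = 4
    # One flag each for layer2, layer3 and layer4.
    for use_dilation in replace_stride_with_dilation:
        if not use_dilation:
            stride *= 2  # this stage halves the spatial resolution
    return stride


default_stride = resnet_output_stride([False, False, False])
dilated_stride = resnet_output_stride([False, True, True])

# Spatial size of the layer4 feature map for a 512 x 512 input.
input_size = 512
feature_size = input_size // dilated_stride
```

Under these assumptions the dilated backbone has an output stride of 8 instead of 32, so the segmentation heads see a denser feature map. In a ResNet-50, layer3 and layer4 produce 1024 and 2048 channels respectively, which is consistent with the aux_in_channels=1024 and classifier_in_channels=2048 defaults in the signature.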
Arguments:
num_classes: Number of classes in the segmentation task. Also controls the final output shape (num_classes, height, width). Defaults to 21.
backbone: The neural network to use for extracting features. If None, then all params are set to DeepLabV3_RESNET50 with untrained weights.
intermediate_layers: Layers from backbone to be used for generating output maps. Defaults to layer3 and layer4 from DeepLabV3_RESNET50.
classifier_module: Uses the DeepLabHead by default.
classifier_in_channels: Number of input channels from the last intermediate layer.
aux_classifier_module: Uses the FCNHead by default.
aux_in_channels: Number of channels in the auxiliary output. Used when the number of intermediate_layers is equal to 2.
silence_layers: Layers of a network not used in training. Typically, for a backbone ported from classification, the fc layers can be dropped. This is particularly useful when loading weights from torchvision. By default, the .fc layer of a model is set to identity to avoid tracking weights.
torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation.