FCN #

eqxvision.models.FCN #

Ported from torchvision.models.segmentation.fcn

__init__(self, backbone: eqx.Module, classifier: eqx.Module, aux_classifier: Optional[eqx.Module] = None) #

Arguments:

  • backbone: the network used to compute the features for the model. The backbone returns a tuple of (embedding_features (ignored), [output features of intermediate layers])
  • classifier: module that takes the last of the intermediate outputs from the backbone and returns a dense prediction
  • aux_classifier: if used, an auxiliary classifier similar to classifier, applied to the auxiliary layer
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Tuple[Union[Any, Array], Array] #

Arguments:

  • x: The input. Should be a JAX array with 3 channels
  • key: Required parameter. Utilised by a few layers such as Dropout or DropPath

Returns: A tuple with outputs from the intermediate and last layers.
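
A minimal sketch of a forward pass, assuming the fcn factory below is used with its untrained FCN_RESNET50 defaults and a 3-channel, channels-first input (the tuple unpacking and output shape follow the signature and num_classes documentation on this page):

import jax
import jax.numpy as jnp

from eqxvision.models import fcn

net = fcn(key=jax.random.PRNGKey(0))  # untrained FCN_RESNET50 configuration by default
x = jnp.zeros((3, 224, 224))  # a single 3-channel input image
intermediate_out, out = net(x, key=jax.random.PRNGKey(1))
# out is the dense prediction with shape (num_classes, height, width)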


eqxvision.models.fcn(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, classifier_module: eqx.Module = None, classifier_in_channels: int = 2048, aux_in_channels: int = None, silence_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> FCN #

Implements the FCN model from the Fully Convolutional Networks for Semantic Segmentation paper.

Sample call

# imports below assume eqxvision's public API for models and pretrained-weight URLs
from eqxvision.models import fcn, resnet50
from eqxvision.utils import SEGMENTATION_URLS

net = fcn(
    backbone=resnet50(replace_stride_with_dilation=[False, True, True]),
    intermediate_layers=lambda x: [x.layer3, x.layer4],
    aux_in_channels=1024,
    torch_weights=SEGMENTATION_URLS["fcn_resnet50"],
)
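
In this sample, replace_stride_with_dilation=[False, True, True] keeps layer3 and layer4 at a higher spatial resolution, as in torchvision's fcn_resnet50 setup, and aux_in_channels=1024 matches the number of output channels of ResNet-50's layer3, which feeds the auxiliary classifier.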

Arguments:

  • num_classes: Number of classes in the segmentation task. Also controls the final output shape (num_classes, height, width). Defaults to 21
  • backbone: The neural network to use for extracting features. If None, all parameters default to the FCN_RESNET50 configuration with untrained weights
  • intermediate_layers: Layers from the backbone to be used for generating output maps. Defaults to layer3 and layer4 of FCN_RESNET50
  • classifier_module: Uses the FCNHead by default
  • classifier_in_channels: Number of input channels from the last intermediate layer
  • aux_in_channels: Number of channels in the auxiliary output. It is used when the number of intermediate_layers is equal to 2.
  • silence_layers: Layers of the network that are not used in training. Typically, for a backbone ported from classification, the fc layers can be dropped. This is particularly useful when loading weights from torchvision. By default, the .fc layer of a model is set to identity to avoid tracking its weights.
  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. Defaults to None
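
For batched inference, a hedged sketch (assuming an untrained default model; eqx.tree_inference and jax.vmap are standard Equinox/JAX utilities, and the batch size and image shape here are hypothetical):

import equinox as eqx
import jax
import jax.numpy as jnp

from eqxvision.models import fcn

net = fcn(num_classes=21, key=jax.random.PRNGKey(0))  # untrained defaults
net = eqx.tree_inference(net, value=True)             # flip inference flags on layers such as Dropout
images = jnp.zeros((4, 3, 224, 224))                  # hypothetical batch of four RGB images
keys = jax.random.split(jax.random.PRNGKey(1), 4)
_, preds = jax.vmap(lambda img, k: net(img, key=k))(images, keys)
seg_maps = jnp.argmax(preds, axis=1)                  # per-pixel class indices, shape (4, 224, 224)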