FCN
eqxvision.models.FCN
Ported from torchvision.models.segmentation.fcn
__init__(self, backbone: eqx.Module, classifier: eqx.Module, aux_classifier: Optional[eqx.Module] = None)
Arguments:
backbone
: The network used to compute the features for the model. The backbone returns (embedding_features, [output features of intermediate layers]); embedding_features is ignored.
classifier
: Module that takes the last of the intermediate outputs from the backbone and returns a dense prediction.
aux_classifier
: If used, an auxiliary classifier, similar to classifier, for the auxiliary layer.
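The way FCN wires these three modules together can be sketched with plain JAX functions. This is a minimal, shape-only sketch under several assumptions, not eqxvision's implementation: toy_backbone, conv1x1 and fcn_forward are hypothetical names, the backbone is assumed to return two intermediate maps at 1/8 resolution (1024 and 2048 channels, as with a dilated ResNet-50), each classifier is reduced to a single 1x1 convolution, and both outputs are assumed to be bilinearly upsampled to the input's spatial size, as torchvision's FCN does.

```python
import jax
import jax.numpy as jnp


def toy_backbone(x):
    """Hypothetical backbone: returns (embedding_features, [intermediate maps]).

    Assumption: two intermediate maps at 1/8 resolution with 1024 and 2048
    channels. Values are constant; only the shapes matter here.
    """
    _, h, w = x.shape
    feat3 = jnp.ones((1024, h // 8, w // 8))
    feat4 = jnp.ones((2048, h // 8, w // 8))
    embedding = jnp.ones((2048,))  # returned but ignored by FCN
    return embedding, [feat3, feat4]


def conv1x1(weights, feats):
    # A 1x1 convolution is a per-pixel linear map over channels:
    # (K, C) x (C, H, W) -> (K, H, W)
    return jnp.einsum("kc,chw->khw", weights, feats)


def fcn_forward(x, cls_w, aux_w=None):
    """Return (aux, out); aux is None when no auxiliary classifier is given."""
    _, feats = toy_backbone(x)
    out_hw = x.shape[-2:]
    # Main head: last intermediate output -> dense prediction -> input resolution
    out = conv1x1(cls_w, feats[-1])
    out = jax.image.resize(out, (out.shape[0], *out_hw), method="bilinear")
    aux = None
    if aux_w is not None:
        # Auxiliary head: the earlier intermediate output (e.g. layer3)
        aux = conv1x1(aux_w, feats[0])
        aux = jax.image.resize(aux, (aux.shape[0], *out_hw), method="bilinear")
    return aux, out
```

For a (3, 64, 64) input and 21 classes, both outputs come back at full resolution with shape (21, 64, 64); dropping aux_w yields (None, out).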
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Tuple[Union[Any, Array], Array]
Arguments:
x
: The input. Should be a JAX array with 3 channels.
key
: Required parameter. Used by a few layers, such as Dropout or DropPath.
Returns: A tuple with outputs from the intermediate and last layers.
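Because key is required, batched inference needs one key per example. The sketch below assumes, as is usual for Equinox models, that the model takes a single unbatched (C, H, W) image and returns a tuple whose first element (the auxiliary output) may be None. toy_model and batched_forward are hypothetical names; toy_model only mimics the call signature and output shapes of an FCN instance.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 21  # assumption: matches the fcn() default


def toy_model(x, *, key):
    """Hypothetical stand-in for an FCN instance: (C, H, W) -> (aux, out)."""
    _, h, w = x.shape
    # Use the key the way a Dropout layer would: a random keep-mask
    keep = jax.random.bernoulli(key, 0.9, (NUM_CLASSES, h, w)).astype(x.dtype)
    out = keep * jnp.mean(x)  # fake logits of shape (NUM_CLASSES, H, W)
    return None, out  # no auxiliary classifier in this toy


def batched_forward(model, images, key):
    """Apply a single-image model to a (B, C, H, W) batch."""
    # One independent key per image so stochastic layers differ per example
    keys = jax.random.split(key, images.shape[0])
    # Keyword arguments to a vmapped function are mapped over their leading axis
    return jax.vmap(model)(images, key=keys)
```

With a (4, 3, 32, 32) batch, batched_forward returns (None, out) with out of shape (4, 21, 32, 32); a None auxiliary output passes through vmap unchanged because it contains no arrays.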
eqxvision.models.fcn(num_classes: Optional[int] = 21, backbone: eqx.Module = None, intermediate_layers: Callable = None, classifier_module: eqx.Module = None, classifier_in_channels: int = 2048, aux_in_channels: int = None, silence_layers: Callable = None, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None) -> FCN
Implements the FCN model from the Fully Convolutional Networks for Semantic Segmentation paper.
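Sample call:

```python
# Assumed import locations for the names used below
from eqxvision.models import fcn, resnet50
from eqxvision.utils import SEGMENTATION_URLS

net = fcn(
    backbone=resnet50(replace_stride_with_dilation=[False, True, True]),
    intermediate_layers=lambda x: [x.layer3, x.layer4],
    aux_in_channels=1024,
    torch_weights=SEGMENTATION_URLS["fcn_resnet50"],
)
```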
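The main output has shape (num_classes, height, width), so a per-pixel label map is an argmax over the class axis. The sketch below shows that step together with a training loss that adds the auxiliary output with weight 0.5 and skips pixels labelled 255; that weighting and ignore label are borrowed from torchvision's reference segmentation training recipe and are assumptions here, not something this model prescribes. logits_to_mask, pixel_cross_entropy and fcn_loss are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def logits_to_mask(out):
    """(num_classes, H, W) or (B, num_classes, H, W) logits -> integer label map."""
    # Each pixel takes the class with the highest score (ties -> lowest index)
    return jnp.argmax(out, axis=-3)


def pixel_cross_entropy(logits, labels, ignore_index=255):
    """Mean cross-entropy over non-ignored pixels.

    logits: (K, H, W) floats; labels: (H, W) integers in [0, K) or ignore_index.
    """
    log_probs = jax.nn.log_softmax(logits, axis=0)
    valid = labels != ignore_index
    safe_labels = jnp.where(valid, labels, 0)  # keep indices in range
    # Log-probability of the true class at every pixel: (H, W)
    picked = jnp.take_along_axis(log_probs, safe_labels[None], axis=0)[0]
    total = -jnp.sum(jnp.where(valid, picked, 0.0))
    return total / jnp.maximum(valid.sum(), 1)


def fcn_loss(aux, out, labels, aux_weight=0.5):
    """Main loss plus a down-weighted auxiliary loss when aux is present."""
    loss = pixel_cross_entropy(out, labels)
    if aux is not None:
        loss = loss + aux_weight * pixel_cross_entropy(aux, labels)
    return loss
```

With all-zero logits over 21 classes, every pixel has probability 1/21, so the main loss is log(21) and adding an identical auxiliary output gives 1.5 * log(21).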
Arguments:
num_classes
: Number of classes in the segmentation task. Also controls the final output shape (num_classes, height, width). Defaults to 21.
backbone
: The neural network to use for extracting features. If None, all params are set to FCN_RESNET50 with untrained weights.
intermediate_layers
: Layers from backbone to be used for generating output maps. By default, these are layer3 and layer4 from FCN_RESNET50.
classifier_module
: The module used as the classifier. Uses FCNHead by default.
classifier_in_channels
: Number of input channels from the last intermediate layer.
aux_in_channels
: Number of channels in the auxiliary intermediate output, i.e. the input to the auxiliary classifier. It is used when the number of intermediate_layers is equal to 2.
silence_layers
: Layers of a network not used in training. Typically, for a backbone ported from classification, the fc layers can be dropped. This is particularly useful when loading weights from torchvision. By default, the .fc layer of a model is set to identity to avoid tracking weights.
torch_weights
: A Path or URL for the PyTorch weights. Defaults to None.
key
: A jax.random.PRNGKey used to provide randomness for parameter