Conv-Norm-Activation
eqxvision.layers.ConvNormActivation
A simple port of torchvision.ops.misc.ConvNormActivation.
Packs convolution -> normalisation -> activation into a single, easy-to-use module.
__init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, norm_layer: Optional[Callable] = eqx.experimental.BatchNorm, activation_layer: Optional[Callable] = jax.nn.relu, dilation: int = 1, use_bias: Optional[bool] = None, *, key: jax.random.PRNGKey = None)
Arguments:
in_channels: Number of channels in the input image.
out_channels: Number of channels produced by the Convolution-Normalisation-Activation block.
kernel_size: Size of the convolution kernel. Defaults to 3.
stride: Stride of the convolution. Defaults to 1.
padding: Padding added to all four sides of the input. Defaults to None, in which case it is calculated as padding = (kernel_size - 1) // 2 * dilation.
groups: Number of blocked connections from input channels to output channels. Defaults to 1.
norm_layer: Normalisation layer stacked on top of the convolution layer. If None, this layer won't be used. Defaults to eqx.experimental.BatchNorm.
activation_layer: Activation function stacked on top of the normalisation layer (if not None), otherwise on top of the convolution layer. If None, this layer won't be used. Defaults to jax.nn.relu.
dilation: Spacing between kernel elements. Defaults to 1.
use_bias: If True, a bias is used in the convolution layer. Defaults to None, in which case a bias is included only if norm_layer is None.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
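The documented defaults for padding and use_bias can be checked in plain Python. The sketch below is not eqxvision's code: resolve_defaults and conv_output_size are hypothetical helpers that restate the documented padding formula and bias rule, and conv_output_size assumes the usual convolution output-length formula applies to the underlying convolution.

```python
from typing import Callable, Optional, Tuple


def resolve_defaults(
    kernel_size: int = 3,
    dilation: int = 1,
    padding: Optional[int] = None,
    norm_layer: Optional[Callable] = None,
    use_bias: Optional[bool] = None,
) -> Tuple[int, bool]:
    """Hypothetical helper: apply the documented defaults for padding and use_bias."""
    if padding is None:
        # Documented formula; `//` and `*` bind left to right: ((k - 1) // 2) * d
        padding = (kernel_size - 1) // 2 * dilation
    if use_bias is None:
        # Documented default: include a conv bias only when there is no norm layer
        use_bias = norm_layer is None
    return padding, use_bias


def conv_output_size(
    size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Hypothetical helper: usual convolution output length along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


if __name__ == "__main__":
    # Compare a few (kernel, dilation, stride) settings on a 32-pixel input
    for k, d, s in [(3, 1, 1), (3, 2, 1), (3, 1, 2), (4, 1, 1)]:
        p, bias = resolve_defaults(kernel_size=k, dilation=d)
        out = conv_output_size(32, k, stride=s, padding=p, dilation=d)
        print(f"kernel={k} dilation={d} stride={s}: padding={p}, use_bias={bias}, 32 -> {out}")
```

Under these assumptions, the default padding preserves spatial size at stride 1 only for odd kernel_size. An even kernel such as 4 gets padding 1 and shrinks a 32-pixel input to 31.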
__call__(self, x: Any, *, key: Optional[jax.random.PRNGKey] = None) -> Any
Arguments:
x: Argument passed to the first member of the sequence.
key: A jax.random.PRNGKey, which will be split and passed to every layer to provide any desired randomness. (Optional. Keyword-only argument.)
Returns:
The output of the last member of the sequence.
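The __call__ contract above (input to the first member, output from the last, one key split across all layers) appears to match that of eqx.nn.Sequential. Below is a minimal pure-Python sketch of that contract. call_sequence and its split parameter are hypothetical. split stands in for jax.random.split only so the sketch runs without JAX, and the three toy stages are not real eqxvision layers.

```python
from typing import Any, Callable, List, Optional, Sequence

# Hypothetical stand-in for jax.random.split(key, n): any function returning n subkeys.
SplitFn = Callable[[Any, int], Sequence[Any]]


def call_sequence(
    layers: Sequence[Callable[..., Any]],
    x: Any,
    *,
    key: Optional[Any] = None,
    split: Optional[SplitFn] = None,
) -> Any:
    """Pass x through layers in order; each layer receives its own key (or None)."""
    if key is None:
        # No key given: every layer is called with key=None
        keys: List[Optional[Any]] = [None] * len(layers)
    else:
        if split is None:
            raise ValueError("`split` is required when `key` is given")
        # One subkey per layer, as jax.random.split(key, len(layers)) would produce
        keys = list(split(key, len(layers)))
    for layer, layer_key in zip(layers, keys):
        # The first layer sees the original input; later layers see the previous output
        x = layer(x, key=layer_key)
    return x


# Toy stand-ins for the convolution, normalisation and activation stages.
def conv(x, *, key=None):
    return 2 * x


def norm(x, *, key=None):
    return x - 1


def act(x, *, key=None):
    return max(x, 0)


result = call_sequence([conv, norm, act], 3)  # 3 -> 6 -> 5 -> 5
```

Splitting a single key gives each stage an independent source of randomness (for example, a stochastic layer), while passing no key leaves every stage deterministic.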