Extensions2d
eqxvision.layers.LayerNorm2d
Wraps `eqx.nn.LayerNorm` as an easy-to-apply channelwise variant.
`__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array`
Arguments:
- `x`: The input JAX array of shape `(channels, dim_0, dim_1)`.
- `key`: Ignored.
Returns:
The output of `eqx.nn.LayerNorm` applied channelwise, i.e. to the `channels`-length vector at each of the `dim_0 * dim_1` spatial positions.
`__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)`
Arguments:
- `shape`: Input shape. May be left unspecified (e.g. just `None`) if `elementwise_affine=False`.
- `eps`: Value added to the denominator for numerical stability.
- `elementwise_affine`: Whether the module has learnable affine parameters.
eqxvision.layers.Linear2d
Wraps `eqx.nn.Linear` as an easy-to-apply channelwise variant.
`__init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jax.random.PRNGKey)`
Arguments:
- `in_features`: The input size.
- `out_features`: The output size.
- `use_bias`: Whether to add a bias as well.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword-only argument.)
`__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array`
Arguments:
- `x`: The input JAX array of shape `(channels, dim_0, dim_1)`.
- `key`: Ignored.
Returns:
The output of `eqx.nn.Linear` applied channelwise, i.e. to the `channels`-length vector at each of the `dim_0 * dim_1` spatial positions.