Extensions2d
eqxvision.layers.LayerNorm2d
Wraps eqx.nn.LayerNorm to provide an easy-to-apply channelwise variant.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array
Arguments:
x: The input JAX array of shape (channels, dim_0, dim_1).
key: Ignored.
Returns:
Output of eqx.nn.LayerNorm applied to each entry of the reshaped (dim_0*dim_1) x c array, where c is the number of channels.
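The channelwise behaviour described above can be illustrated in plain JAX. This is a minimal sketch, not eqxvision's implementation: it assumes the normalisation runs over the channel vector at each of the dim_0*dim_1 spatial positions and leaves out the learnable affine parameters. The name layer_norm_2d_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def layer_norm_2d_sketch(x, eps=1e-5):
    """Normalise the channel vector at every spatial position of a (c, h, w) array.

    Hypothetical helper for illustration only; no affine parameters are applied.
    """
    c, h, w = x.shape
    # Flatten the spatial dimensions and move channels last: (h*w, c).
    tokens = x.reshape(c, h * w).T

    def norm_one(v):
        # Standard layer normalisation of a single channel vector.
        return (v - jnp.mean(v)) / jnp.sqrt(jnp.var(v) + eps)

    out = jax.vmap(norm_one)(tokens)
    # Restore the original (c, h, w) layout.
    return out.T.reshape(c, h, w)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 5))
y = layer_norm_2d_sketch(x)
print(y.shape)
```

Under this assumption the output keeps the input shape, and the values along the channel axis at each spatial position have approximately zero mean.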
__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)
Arguments:
shape: Input shape. May be left unspecified (e.g. just None) if elementwise_affine=False.
eps: Value added to denominator for numerical stability.
elementwise_affine: Whether the module has learnable affine parameters.
eqxvision.layers.Linear2d
Wraps eqx.nn.Linear to provide an easy-to-apply channelwise variant.
__init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jax.random.PRNGKey)
Arguments:
in_features: The input size.
out_features: The output size.
use_bias: Whether to add on a bias as well.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array
Arguments:
x: The input JAX array of shape (channels, dim_0, dim_1).
key: Ignored.
Returns:
Output of eqx.nn.Linear applied to each entry of the reshaped (dim_0*dim_1) x c array, where c is the number of channels.
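The same idea can be sketched for the linear case in plain JAX. This is a rough illustration rather than eqxvision's code: it assumes the layer applies one shared linear map to the channel vector at each spatial position, with a weight of shape (out_features, in_features) as in eqx.nn.Linear. The name linear_2d_sketch and the randomly drawn parameters are hypothetical.

```python
import jax
import jax.numpy as jnp


def linear_2d_sketch(weight, bias, x):
    """Apply one shared linear map to the channel vector at every spatial position.

    Hypothetical helper for illustration only.
    weight: (out_features, in_features), bias: (out_features,), x: (in_features, h, w).
    """
    c, h, w = x.shape
    # Flatten the spatial dimensions and move channels last: (h*w, in_features).
    tokens = x.reshape(c, h * w).T
    # Map each channel vector independently: (h*w, out_features).
    out = jax.vmap(lambda v: weight @ v + bias)(tokens)
    # Move the new channels first and restore the spatial layout.
    return out.T.reshape(-1, h, w)


key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
in_features, out_features = 3, 6
weight = jax.random.normal(key_w, (out_features, in_features))
bias = jax.random.normal(key_b, (out_features,))
x = jax.random.normal(key_x, (in_features, 4, 5))
y = linear_2d_sketch(weight, bias, x)
print(y.shape)
```

Under this assumption the spatial dimensions are unchanged and only the channel dimension moves from in_features to out_features.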