
Extensions2d

eqxvision.layers.LayerNorm2d

Wraps eqx.nn.LayerNorm as an easy-to-apply channelwise variant.

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input JAX array of shape (channels, dim_0, dim_1)
  • key: Ignored

Returns:

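Output of eqx.nn.LayerNorm applied to each of the dim_0*dim_1 channel vectors (the input viewed as a dim_0*dim_1 x channels array and normalised row by row), returned with the input's shape (channels, dim_0, dim_1).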
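The channelwise normalisation described above can be sketched in plain JAX. This is a minimal illustration, not eqxvision's implementation: the helper name channelwise_layer_norm is hypothetical, and it omits the learnable affine parameters. It views the image as dim_0*dim_1 rows of length channels, normalises each row with jax.vmap, and restores the original layout.

```python
import jax
import jax.numpy as jnp


def channelwise_layer_norm(x, eps=1e-5):
    """Hypothetical helper: normalise x of shape (channels, dim_0, dim_1) over channels.

    No learnable weight/bias is applied in this sketch.
    """
    c, h, w = x.shape
    # View the image as dim_0*dim_1 rows, each a length-c channel vector.
    rows = jnp.transpose(x.reshape(c, h * w))  # (h*w, c)

    def normalise(v):
        # Zero mean, unit variance for a single channel vector.
        return (v - jnp.mean(v)) / jnp.sqrt(jnp.var(v) + eps)

    # Apply the per-vector normalisation at every spatial position.
    rows = jax.vmap(normalise)(rows)
    # Restore the (channels, dim_0, dim_1) layout.
    return jnp.transpose(rows).reshape(c, h, w)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 5))
y = channelwise_layer_norm(x)
print(y.shape)  # (8, 4, 5)
```

The reshape/vmap form is equivalent to reducing over axis 0 directly; it is written this way to mirror the "each dim_0*dim_1 x c entry" description.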

__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)

Arguments:

  • shape: Input shape; for this channelwise variant, the number of channels. May be left unspecified (i.e. None) if elementwise_affine=False.
  • eps: Value added to denominator for numerical stability.
  • elementwise_affine: Whether the module has learnable affine parameters.
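To illustrate how eps and elementwise_affine are meant to interact, here is a hedged sketch in plain JAX. The helpers make_affine_params and layer_norm_vector are hypothetical, and initialising the weight to ones and the bias to zeros is an assumption (an identity affine map at the start of training). With elementwise_affine=False no parameters exist, which is why shape can be None in that case.

```python
import jax.numpy as jnp


def make_affine_params(shape, elementwise_affine=True):
    """Hypothetical helper: return (weight, bias), or (None, None) without affine.

    Assumption: weight starts at ones and bias at zeros.
    """
    if not elementwise_affine:
        # No learnable parameters, so shape is not needed.
        return None, None
    return jnp.ones(shape), jnp.zeros(shape)


def layer_norm_vector(v, weight=None, bias=None, eps=1e-5):
    """Normalise one channel vector v, then apply the optional affine map."""
    # eps keeps the denominator away from zero when v is (nearly) constant.
    out = (v - jnp.mean(v)) / jnp.sqrt(jnp.var(v) + eps)
    if weight is not None:
        out = out * weight + bias
    return out


weight, bias = make_affine_params(3)
v = jnp.array([1.0, 2.0, 3.0])
print(layer_norm_vector(v, weight, bias))
print(layer_norm_vector(jnp.array([5.0, 5.0, 5.0])))  # constant input -> zeros, no NaN
```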

eqxvision.layers.Linear2d

Wraps eqx.nn.Linear as an easy-to-apply channelwise variant.

__init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jax.random.PRNGKey)

Arguments:

  • in_features: The input size.
  • out_features: The output size.
  • use_bias: Whether to add on a bias as well.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
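The constructor arguments determine the parameter shapes: one weight row per output channel and one column per input channel, plus an optional bias per output channel. The sketch below is hypothetical (init_linear2d is not an eqxvision function), and the uniform range of plus or minus 1/sqrt(in_features) is an assumption chosen for illustration. It also shows why key is keyword-only and how it is split so the weight and bias draw independent randomness.

```python
import math

import jax


def init_linear2d(in_features, out_features, use_bias=True, *, key):
    """Hypothetical initialiser returning (weight, bias) for a channelwise linear map.

    Assumption: uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)).
    """
    wkey, bkey = jax.random.split(key)
    lim = 1.0 / math.sqrt(in_features)
    # One row per output channel, one column per input channel.
    weight = jax.random.uniform(
        wkey, (out_features, in_features), minval=-lim, maxval=lim
    )
    # Bias has one entry per output channel, or is absent when use_bias=False.
    bias = (
        jax.random.uniform(bkey, (out_features,), minval=-lim, maxval=lim)
        if use_bias
        else None
    )
    return weight, bias


weight, bias = init_linear2d(3, 16, key=jax.random.PRNGKey(0))
print(weight.shape, bias.shape)  # (16, 3) (16,)
```

Reusing the same key reproduces the same parameters, which is the usual JAX convention for deterministic initialisation.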

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input JAX array of shape (channels, dim_0, dim_1), where channels must equal in_features
  • key: Ignored

Returns:

Output of eqx.nn.Linear applied to each of the dim_0*dim_1 channel vectors, returned with shape (out_features, dim_0, dim_1).
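The channelwise linear map can be sketched in plain JAX as follows. This is a minimal illustration, not eqxvision's code: linear2d_apply is a hypothetical helper taking explicit weight and bias arrays. It applies the same matrix to every spatial position via jax.vmap, then checks the result against an equivalent einsum contraction over the channel axis.

```python
import jax
import jax.numpy as jnp


def linear2d_apply(weight, bias, x):
    """Hypothetical helper: apply a linear map to every channel vector of x.

    weight: (out_features, in_features); bias: (out_features,) or None;
    x: (in_features, dim_0, dim_1). Returns (out_features, dim_0, dim_1).
    """
    c, h, w = x.shape
    rows = jnp.transpose(x.reshape(c, h * w))  # (h*w, in_features)

    def linear(v):
        out = weight @ v
        return out if bias is None else out + bias

    rows = jax.vmap(linear)(rows)  # (h*w, out_features)
    # The channel count changes, so let reshape infer it.
    return jnp.transpose(rows).reshape(-1, h, w)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (16, 3))
bias = jnp.zeros(16)
x = jax.random.normal(key_x, (3, 8, 8))
y = linear2d_apply(weight, bias, x)
# Same result written as a contraction over the channel axis.
y_ref = jnp.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]
print(y.shape)  # (16, 8, 8)
```

Because one weight matrix is shared across all positions, this layer computes the same thing as a 1x1 convolution with in_features input and out_features output channels.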