
Mlp-Projection

eqxvision.layers.MlpProjection

MLP block as used in Vision Transformer, MLP-Mixer, and related networks.

__init__(self, in_features: int, hidden_features: int = None, out_features: int = None, lin_layer: Union[Linear2d, nn.Linear] = nn.Linear, act_layer: Callable = None, drop: Union[float, Tuple[float]] = 0.0, *, key: jax.random.PRNGKey = None)

Arguments:

  • in_features: The expected dimension of the input
  • hidden_features: Dimensionality of the hidden layer
  • out_features: The dimension of the output feature
  • lin_layer: Linear layer to use. For transformer like architectures, Linear2d can be easier to integrate.
  • act_layer: Activation function to be applied to the intermediate layers
  • drop: The probability associated with Dropout
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
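In ViT-style networks this block is usually laid out as: a linear map to `hidden_features`, the activation, dropout, a second linear map to `out_features`, then dropout again. The sketch below writes that computation out in plain JAX for a single feature vector, assuming that layout. It is an illustration, not eqxvision's implementation: `init_mlp`, `dropout` and `mlp_projection` are hypothetical helpers, the weight initialisation is a simple scaled normal rather than the initialiser of `lin_layer`, GELU stands in for `act_layer`, and both the fall-back of `hidden_features`/`out_features` to `in_features` and the per-layer use of a `drop` tuple are assumptions borrowed from timm's `Mlp` convention.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_features, hidden_features=None, out_features=None):
    """Hypothetical initialiser for a two-layer MLP projection.

    Assumption (timm-style): missing hidden/out sizes fall back to in_features.
    """
    hidden_features = hidden_features or in_features
    out_features = out_features or in_features
    k1, k2 = jax.random.split(key)
    # Scaled-normal init for illustration only; the real layer uses lin_layer's own init.
    w1 = jax.random.normal(k1, (hidden_features, in_features)) / jnp.sqrt(in_features)
    w2 = jax.random.normal(k2, (out_features, hidden_features)) / jnp.sqrt(hidden_features)
    return {
        "w1": w1, "b1": jnp.zeros(hidden_features),
        "w2": w2, "b2": jnp.zeros(out_features),
    }


def dropout(x, p, key):
    """Inverted dropout: zero each entry with probability p and rescale the rest."""
    if p == 0.0:
        return x
    keep = jax.random.bernoulli(key, 1.0 - p, x.shape)
    return jnp.where(keep, x / (1.0 - p), 0.0)


def mlp_projection(params, x, *, key, act=jax.nn.gelu, drop=0.0):
    """fc1 -> act -> dropout -> fc2 -> dropout on a single (unbatched) vector x."""
    # Assumption: a float is reused for both dropout layers; a tuple gives one per layer.
    p1, p2 = drop if isinstance(drop, tuple) else (drop, drop)
    k1, k2 = jax.random.split(key)  # one subkey per stochastic layer
    h = act(params["w1"] @ x + params["b1"])
    h = dropout(h, p1, k1)
    y = params["w2"] @ h + params["b2"]
    return dropout(y, p2, k2)


# Usage: project an 8-dim vector through a 32-dim hidden layer and back to 8 dims.
params = init_mlp(jax.random.PRNGKey(0), in_features=8, hidden_features=32)
y = mlp_projection(params, jnp.ones(8), key=jax.random.PRNGKey(1), drop=0.1)
print(y.shape)  # (8,)
```

Splitting the call-time key gives each dropout layer its own randomness from a single `key`; with `drop=0.0` the key has no effect on the output.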

__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array

Arguments:

  • x: The input JAX array
  • key: A jax.random.PRNGKey used by the stochastic layers in the network, such as Dropout or DropPath
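Because the key is passed explicitly at call time, the stochastic layers are reproducible: the same `key` yields the same dropout mask, and a new mask needs a new key. The snippet below illustrates this with a hand-written inverted dropout in JAX; `dropout` is a hypothetical stand-in for the library's Dropout layer, and the split-per-step loop is the common JAX convention rather than an eqxvision API.

```python
import jax
import jax.numpy as jnp


def dropout(x, p, key):
    """Inverted dropout: zero entries with probability p, rescale survivors by 1/(1-p)."""
    keep = jax.random.bernoulli(key, 1.0 - p, x.shape)
    return jnp.where(keep, x / (1.0 - p), 0.0)


x = jnp.ones(1000)
key = jax.random.PRNGKey(0)

# Same key -> identical mask, so a forward pass can be replayed exactly.
a = dropout(x, 0.5, key)
b = dropout(x, 0.5, key)
print(bool(jnp.array_equal(a, b)))  # True

# Typical training loop: split off a fresh subkey for every call.
masks = []
for _ in range(3):
    key, subkey = jax.random.split(key)
    masks.append(dropout(x, 0.5, subkey) > 0)
```

Reusing one key across steps would apply the same mask every time, which defeats the purpose of dropout during training.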