Skip to content

Patch-Embed#

eqxvision.layers.PatchEmbed #

2D Image to Patch Embedding ported from Timm

__init__(self, img_size: Union[int, Tuple[int]] = 224, patch_size: Union[int, Tuple[int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: eqx.Module = None, flatten: bool = True, *, key: Optional[jax.random.PRNGKey] = None) #

Arguments:

  • img_size: The size of the input image. Defaults to (224, 224)
  • patch_size: Size of the patch to construct from the input image. Defaults to (16, 16)
  • in_chans: Number of input channels. Defaults to 3
  • embed_dim: The dimension of the resulting embedding of the patch. Defaults to 768
  • norm_layer: The normalisation to be applied on an input. Defaults to None
  • flatten: If enabled, the 2d patches are flattened to 1d
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array #

Arguments:

  • x: The input. Should be a JAX array of shape(in_chans, img_size[0], img_size[1]).
  • key: Ignored