Patch-Embed#
eqxvision.layers.PatchEmbed
#
2D Image to Patch Embedding ported from Timm
__init__(self, img_size: Union[int, Tuple[int]] = 224, patch_size: Union[int, Tuple[int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: eqx.Module = None, flatten: bool = True, *, key: Optional[jax.random.PRNGKey] = None)
#
Arguments:
img_size: The size of the input image. Defaults to(224, 224)patch_size: Size of the patch to construct from the input image. Defaults to(16, 16)in_chans: Number of input channels. Defaults to3embed_dim: The dimension of the resulting embedding of the patch. Defaults to768norm_layer: The normalisation to be applied on an input. Defaults to Noneflatten: If enabled, the2dpatches are flattened to1dkey: Ajax.random.PRNGKeyused to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array
#
Arguments:
x: The input. Should be a JAX array of shape(in_chans, img_size[0], img_size[1]).key: Ignored