Patch-Embed#
eqxvision.layers.PatchEmbed
#
2D Image to Patch Embedding ported from Timm
__init__(self, img_size: Union[int, Tuple[int]] = 224, patch_size: Union[int, Tuple[int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: eqx.Module = None, flatten: bool = True, *, key: Optional[jax.random.PRNGKey] = None)
#
Arguments:
img_size
: The size of the input image. Defaults to(224, 224)
patch_size
: Size of the patch to construct from the input image. Defaults to(16, 16)
in_chans
: Number of input channels. Defaults to3
embed_dim
: The dimension of the resulting embedding of the patch. Defaults to768
norm_layer
: The normalisation to be applied on an input. Defaults to Noneflatten
: If enabled, the2d
patches are flattened to1d
key
: Ajax.random.PRNGKey
used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array
#
Arguments:
x
: The input. Should be a JAX array of shape(in_chans, img_size[0], img_size[1])
.key
: Ignored