Vision Transformer

eqxvision.models.VisionTransformer

Vision Transformer ported from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py

__init__(self, img_size: Union[int, Tuple[int]] = 224, patch_size: Union[int, Tuple[int]] = 16, in_chans: int = 3, num_classes: int = 0, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale = None, drop_rate = 0.0, attn_drop_rate = 0.0, drop_path_rate = 0.0, norm_layer = <class 'equinox.nn.normalisation.LayerNorm'>, *, key: Optional[jax.random.PRNGKey] = None)

Arguments:

  • img_size: The size of the input image. Defaults to 224, i.e. (224, 224)
  • patch_size: Size of the patches extracted from the input image. Defaults to 16, i.e. (16, 16)
  • in_chans: Number of input channels. Defaults to 3
  • num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,). If num_classes=0 (the default), the final layer is replaced with nn.Identity
  • embed_dim: The dimension of each patch embedding. Defaults to 768
  • depth: Number of VitBlocks in the network. Defaults to 12
  • num_heads: Number of attention heads; the embedding dimension is split into this many equal parts. Defaults to 12
  • mlp_ratio: Ratio used to compute the hidden dimension of the MLP from embed_dim. Defaults to 4.0
  • qkv_bias: Whether to add a bias within the qkv computation. Defaults to True
  • qk_scale: Scale applied to the query-key product in attention. Defaults to None
  • drop_rate: Dropout rate used within the MlpProjection. Defaults to 0.0
  • attn_drop_rate: Dropout rate used within the attention modules. Defaults to 0.0
  • drop_path_rate: Drop-path (stochastic depth) rate used within VitBlocks. Defaults to 0.0
  • norm_layer: Normalisation applied to the intermediate outputs. Defaults to equinox.nn.LayerNorm
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
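
The arguments above determine a few derived quantities that are useful when checking shapes. The following is a minimal sketch in plain Python, not eqxvision code: the helper names (to_pair, vit_geometry, drop_path_schedule) are hypothetical, and the formulas assume this port follows the referenced DINO implementation (one extra class token, a default attention scale of head_dim ** -0.5 when qk_scale is None, and a drop-path rate that grows linearly from 0 to drop_path_rate over the blocks).

```python
def to_pair(value):
    """Expand an int into an (h, w) pair; pass tuples through unchanged."""
    return (value, value) if isinstance(value, int) else tuple(value)


def vit_geometry(img_size=224, patch_size=16, embed_dim=768, num_heads=12, qk_scale=None):
    """Derived sizes for a ViT, assuming DINO-style conventions (an assumption)."""
    img_h, img_w = to_pair(img_size)
    patch_h, patch_w = to_pair(patch_size)
    grid = (img_h // patch_h, img_w // patch_w)  # patches per column, per row
    num_patches = grid[0] * grid[1]
    head_dim = embed_dim // num_heads  # embed_dim split equally across heads
    # Assumed default: scale the query-key product by 1 / sqrt(head_dim).
    scale = qk_scale if qk_scale is not None else head_dim ** -0.5
    return {
        "grid": grid,
        "num_patches": num_patches,
        "num_tokens": num_patches + 1,  # +1 for the assumed class token
        "head_dim": head_dim,
        "scale": scale,
    }


def drop_path_schedule(drop_path_rate, depth):
    """Per-block drop-path rates, rising linearly from 0 to drop_path_rate (assumed)."""
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


geometry = vit_geometry()  # the documented defaults
print(geometry["grid"], geometry["num_tokens"], geometry["head_dim"])
print(drop_path_schedule(0.1, 12))
```

With the documented defaults this gives a 14 x 14 patch grid; changing img_size or patch_size changes the token count, while num_heads only changes how embed_dim is divided.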

__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array

Arguments:

  • x: The input JAX array
  • key: Required parameter. Used by a few layers such as Dropout or DropPath

get_last_self_attention(self, x: Array, *, key: jax.random.PRNGKey) -> Array

Arguments:

  • x: The input JAX array
  • key: Used by a few layers in the network such as Dropout or DropPath
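
A common use of the last self-attention is to view, per head, how strongly the class token attends to each patch. The sketch below illustrates that reshaping with plain nested lists. It is not eqxvision code: cls_attention_maps is a hypothetical helper, and the layout it expects, (num_heads, num_tokens, num_tokens) with the class token at index 0, is an assumption based on the referenced DINO implementation rather than something this page states.

```python
def cls_attention_maps(attn, grid_h, grid_w):
    """Reshape class-token attention into one (grid_h, grid_w) map per head.

    attn is assumed to be indexed as attn[head][query_token][key_token],
    with the class token stored at index 0 (an assumption, see above).
    """
    maps = []
    for head in attn:
        cls_to_patches = head[0][1:]  # class-token row, minus attention to itself
        if len(cls_to_patches) != grid_h * grid_w:
            raise ValueError("token count does not match the patch grid")
        rows = [
            cls_to_patches[r * grid_w:(r + 1) * grid_w]
            for r in range(grid_h)
        ]
        maps.append(rows)
    return maps


# Toy input: 1 head, a 2 x 2 patch grid, so 4 patches + 1 class token = 5 tokens.
toy_attn = [[[0.2, 0.1, 0.3, 0.2, 0.2]] + [[0.0] * 5 for _ in range(4)]]
print(cls_attention_maps(toy_attn, 2, 2))
```

For real outputs the same indexing would be done with array slicing and a reshape; the list version only makes the index bookkeeping explicit.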

eqxvision.models.vit_tiny(patch_size: int = 16, embed_dim: int = 192, depth: int = 12, num_heads: int = 3, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)

Arguments:

  • patch_size: Size of the patch to construct from the input image
  • embed_dim: The dimension of the resulting embedding of the patch
  • depth: Number of VitBlocks in the network
  • num_heads: Number of attention heads; the embedding dimension is split into this many equal parts
  • mlp_ratio: Ratio used to compute the hidden dimension of the MLP from embed_dim
  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • kwargs: Parameters passed on to the VisionTransformer

eqxvision.models.vit_small(patch_size: int = 16, embed_dim: int = 384, depth: int = 12, num_heads: int = 6, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)

Arguments:

  • patch_size: Size of the patch to construct from the input image
  • embed_dim: The dimension of the resulting embedding of the patch
  • depth: Number of VitBlocks in the network
  • num_heads: Number of attention heads; the embedding dimension is split into this many equal parts
  • mlp_ratio: Ratio used to compute the hidden dimension of the MLP from embed_dim
  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • kwargs: Parameters passed on to the VisionTransformer

eqxvision.models.vit_base(patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)

Arguments:

  • patch_size: Size of the patch to construct from the input image
  • embed_dim: The dimension of the resulting embedding of the patch
  • depth: Number of VitBlocks in the network
  • num_heads: Number of attention heads; the embedding dimension is split into this many equal parts
  • mlp_ratio: Ratio used to compute the hidden dimension of the MLP from embed_dim
  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • kwargs: Parameters passed on to the VisionTransformer
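
To compare the three constructors side by side, the sketch below tabulates their documented defaults and derives the per-head width and MLP hidden size. The PRESETS dictionary and summarize helper are hypothetical and exist only for this illustration; the MLP formula int(embed_dim * mlp_ratio) is an assumption taken from the referenced DINO implementation.

```python
# Defaults copied from the vit_tiny / vit_small / vit_base signatures above.
PRESETS = {
    "vit_tiny": dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4),
    "vit_small": dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4),
    "vit_base": dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4),
}


def summarize(name):
    """Derive per-head width and MLP hidden size for one preset."""
    cfg = PRESETS[name]
    return {
        # embed_dim is split equally across the attention heads.
        "head_dim": cfg["embed_dim"] // cfg["num_heads"],
        # Assumed formula for the MLP hidden dimension.
        "mlp_hidden_dim": int(cfg["embed_dim"] * cfg["mlp_ratio"]),
    }


for preset_name in PRESETS:
    print(preset_name, summarize(preset_name))
```

With these defaults the presets share the same depth, patch size, and per-head width, and differ in embed_dim and num_heads.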