Vision Transformer
eqxvision.models.VisionTransformer
Vision Transformer, ported from the DINO implementation at https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
__init__(self, img_size: Union[int, Tuple[int]] = 224, patch_size: Union[int, Tuple[int]] = 16, in_chans: int = 3, num_classes: int = 0, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale = None, drop_rate = 0.0, attn_drop_rate = 0.0, drop_path_rate = 0.0, norm_layer = <class 'equinox.nn.normalisation.LayerNorm'>, *, key: Optional[jax.random.PRNGKey] = None)
Arguments:
- img_size: The size of the input image. Defaults to (224, 224).
- patch_size: Size of the patches cut from the input image. Defaults to (16, 16).
- in_chans: Number of input channels. Defaults to 3.
- num_classes: Number of classes in the classification task. Also sets the final output shape (num_classes,). If num_classes=0 (the default), the final layer is replaced with nn.Identity.
- embed_dim: Dimension of each patch embedding. Defaults to 768.
- depth: Number of VitBlocks in the network. Defaults to 12.
- num_heads: Number of attention heads; the embedding dimension is split into this many equal parts. Defaults to 12.
- mlp_ratio: Ratio of the MLP hidden dimension to embed_dim. Defaults to 4.0.
- qkv_bias: Whether to add a bias to the query/key/value projection. Defaults to True.
- qk_scale: Scale applied to the query-key dot product in attention. If None, head_dim ** -0.5 is used.
- drop_rate: Dropout rate used within the MlpProjection. Defaults to 0.0.
- attn_drop_rate: Dropout rate used within the attention modules. Defaults to 0.0.
- drop_path_rate: Stochastic depth (DropPath) rate used within the VitBlocks. Defaults to 0.0.
- norm_layer: Normalisation applied to the intermediate outputs. Defaults to equinox.nn.LayerNorm.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
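The constructor arguments fix several derived sizes. The sketch below computes them in plain Python, assuming the DINO layout this model is ported from: non-overlapping patches, one extra [CLS] token, embed_dim split evenly across heads, an MLP hidden width of embed_dim * mlp_ratio, and a default attention scale of 1/sqrt(head_dim). The helpers derived_dims and _pair are hypothetical and not part of eqxvision.

```python
import math
from typing import Dict, Optional, Tuple, Union

Size = Union[int, Tuple[int, int]]


def _pair(v: Size) -> Tuple[int, int]:
    # img_size and patch_size accept either one int or a (height, width) tuple.
    return (v, v) if isinstance(v, int) else (v[0], v[1])


def derived_dims(
    img_size: Size = 224,
    patch_size: Size = 16,
    embed_dim: int = 768,
    num_heads: int = 12,
    mlp_ratio: float = 4.0,
    qk_scale: Optional[float] = None,
) -> Dict[str, float]:
    """Hypothetical helper: sizes implied by VisionTransformer's arguments (DINO layout assumed)."""
    h, w = _pair(img_size)
    ph, pw = _pair(patch_size)
    # The image is cut into non-overlapping patch_size patches.
    num_patches = (h // ph) * (w // pw)
    # Assumption: one learnable [CLS] token is prepended, as in DINO.
    num_tokens = num_patches + 1
    # embed_dim is split into num_heads equal parts.
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    # Hidden width of the MLP inside every VitBlock.
    mlp_hidden = int(embed_dim * mlp_ratio)
    # Attention logits are scaled by qk_scale, or 1/sqrt(head_dim) when it is None.
    scale = qk_scale if qk_scale is not None else 1.0 / math.sqrt(head_dim)
    return {
        "num_patches": num_patches,
        "num_tokens": num_tokens,
        "head_dim": head_dim,
        "mlp_hidden": mlp_hidden,
        "scale": scale,
    }


dims = derived_dims()
print(dims)
```

For the defaults this gives 196 patches, 197 tokens, a head dimension of 64, an MLP hidden width of 3072 and a scale of 0.125.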
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
Arguments:
- x: The input JAX array.
- key: Required. Used by stochastic layers such as Dropout and DropPath.
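The key is required because Dropout and DropPath are random and JAX has no global random state, so randomness must be passed in explicitly. The sketch below illustrates the per-sample DropPath (stochastic depth) rule used in the DINO original: the whole residual branch is either dropped or rescaled by 1/(1 - p). It is plain Python with random.Random(seed) standing in for a jax.random.PRNGKey; the function drop_path and its inference flag are illustrative assumptions, not eqxvision's implementation.

```python
import random
from typing import List


def drop_path(
    residual: List[float],
    drop_prob: float,
    rng: random.Random,
    inference: bool = False,
) -> List[float]:
    """Illustrative stochastic depth for one sample's residual branch."""
    # With no dropping (or at inference time) the branch passes through unchanged.
    if inference or drop_prob == 0.0:
        return list(residual)
    keep_prob = 1.0 - drop_prob
    # One draw per sample decides whether the entire branch is kept.
    if rng.random() < keep_prob:
        # Rescale the kept branch so its expected value equals the input.
        return [r / keep_prob for r in residual]
    return [0.0] * len(residual)


# Same seed, same decision: results are reproducible only through the explicit seed/key.
branch = [1.0, 2.0]
a = drop_path(branch, 0.5, random.Random(0))
b = drop_path(branch, 0.5, random.Random(0))
print(a, a == b)
```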
get_last_self_attention(self, x: Array, *, key: jax.random.PRNGKey) -> Array
Returns the self-attention weights of the last VitBlock.
Arguments:
- x: The input JAX array.
- key: Used by stochastic layers in the network such as Dropout and DropPath.
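In the DINO original, the corresponding method returns the softmax-normalised attention weights of the final block: one N x N map per head, where N is the number of patches plus the [CLS] token. The exact output shape in eqxvision is not stated here. The plain-Python sketch below computes one such map for a single head on toy queries and keys, then reads off the [CLS] row, which is the slice DINO visualises as attention over patches. The names attention_weights, q and k are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q: Matrix, k: Matrix, scale: float) -> Matrix:
    """Single-head self-attention weights: row i is softmax over j of scale * <q_i, k_j>."""
    return [
        softmax([scale * sum(a * b for a, b in zip(q_i, k_j)) for k_j in k])
        for q_i in q
    ]


# Toy input: token 0 plays the role of [CLS], tokens 1-2 are patches; head_dim = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attn = attention_weights(q, k, scale=1.0 / math.sqrt(2))

cls_to_patches = attn[0][1:]  # how strongly [CLS] attends to each patch
print(cls_to_patches)
```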
eqxvision.models.vit_tiny(patch_size: int = 16, embed_dim: int = 192, depth: int = 12, num_heads: int = 3, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)
Arguments:
- patch_size: Size of the patches cut from the input image. Defaults to 16.
- embed_dim: Dimension of each patch embedding. Defaults to 192.
- depth: Number of VitBlocks in the network. Defaults to 12.
- num_heads: Number of attention heads; the embedding dimension is split into this many equal parts. Defaults to 3.
- mlp_ratio: Ratio of the MLP hidden dimension to embed_dim. Defaults to 4.
- torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- kwargs: Further parameters passed on to the VisionTransformer.
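Judging by their signatures, the vit_* factories differ only in their default hyperparameters, with any other VisionTransformer argument forwarded through kwargs. A minimal sketch of that merge, using the defaults listed on this page; PRESETS and preset_config are hypothetical names, not eqxvision functions.

```python
from typing import Any, Dict

# Preset hyperparameters, copied from the vit_* signatures on this page.
PRESETS: Dict[str, Dict[str, Any]] = {
    "vit_tiny": dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4),
    "vit_small": dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4),
    "vit_base": dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4),
}


def preset_config(name: str, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge a preset with extra VisionTransformer kwargs."""
    # Copy so the shared preset table is never mutated.
    config = dict(PRESETS[name])
    # Extra keyword arguments (e.g. img_size, num_classes, drop_path_rate) are
    # forwarded alongside the preset values and may override them.
    config.update(kwargs)
    return config


cfg = preset_config("vit_small", num_classes=1000, drop_path_rate=0.1)
print(cfg)
```

All three presets keep a per-head dimension of 64 (192/3, 384/6, 768/12) and scale width and head count together.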
eqxvision.models.vit_small(patch_size: int = 16, embed_dim: int = 384, depth: int = 12, num_heads: int = 6, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)
Arguments:
- patch_size: Size of the patches cut from the input image. Defaults to 16.
- embed_dim: Dimension of each patch embedding. Defaults to 384.
- depth: Number of VitBlocks in the network. Defaults to 12.
- num_heads: Number of attention heads; the embedding dimension is split into this many equal parts. Defaults to 6.
- mlp_ratio: Ratio of the MLP hidden dimension to embed_dim. Defaults to 4.
- torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- kwargs: Further parameters passed on to the VisionTransformer.
eqxvision.models.vit_base(patch_size: int = 16, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: int = 4, torch_weights: str = None, *, key: Optional[jax.random.PRNGKey] = None, **kwargs)
Arguments:
- patch_size: Size of the patches cut from the input image. Defaults to 16.
- embed_dim: Dimension of each patch embedding. Defaults to 768.
- depth: Number of VitBlocks in the network. Defaults to 12.
- num_heads: Number of attention heads; the embedding dimension is split into this many equal parts. Defaults to 12.
- mlp_ratio: Ratio of the MLP hidden dimension to embed_dim. Defaults to 4.
- torch_weights: A Path or URL for the PyTorch weights. Defaults to None.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- kwargs: Further parameters passed on to the VisionTransformer.
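As a rough guide to model size, the sketch below counts parameters for each preset under the DINO block layout: a patch-embedding convolution, a [CLS] token, a learned positional embedding, qkv and output projections with biases, two LayerNorms and a two-layer MLP per block, and a final LayerNorm. It is an estimate under that assumed layout, not a count taken from eqxvision's modules; vit_param_count is a hypothetical helper.

```python
from typing import Dict


def vit_param_count(
    embed_dim: int,
    depth: int = 12,
    patch_size: int = 16,
    img_size: int = 224,
    in_chans: int = 3,
    mlp_ratio: int = 4,
    num_classes: int = 0,
) -> int:
    """Hypothetical estimate of trainable parameters for a DINO-style ViT."""
    d = embed_dim
    num_patches = (img_size // patch_size) ** 2
    hidden = d * mlp_ratio
    # Patch embedding: a patch_size x patch_size convolution with bias.
    patch_embed = in_chans * patch_size * patch_size * d + d
    # [CLS] token plus one positional embedding per token (patches + [CLS]).
    tokens = d + (num_patches + 1) * d
    per_block = (
        2 * d                  # first LayerNorm (weight + bias)
        + 3 * d * d + 3 * d    # qkv projection, qkv_bias=True
        + d * d + d            # attention output projection
        + 2 * d                # second LayerNorm
        + d * hidden + hidden  # MLP first linear layer
        + hidden * d + d       # MLP second linear layer
    )
    final_norm = 2 * d
    # num_classes=0 means an identity head with no parameters.
    head = d * num_classes + num_classes if num_classes > 0 else 0
    return patch_embed + tokens + depth * per_block + final_norm + head


counts: Dict[str, int] = {
    name: vit_param_count(dim)
    for name, dim in [("vit_tiny", 192), ("vit_small", 384), ("vit_base", 768)]
}
print(counts)
```

With num_classes=0 this gives 5,524,416 (vit_tiny), 21,665,664 (vit_small) and 85,798,656 (vit_base) parameters; per block the count grows as 12 * embed_dim^2 + 13 * embed_dim, so width dominates model size.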