Swin Transformer V1#
eqxvision.models.SwinTransformer
#
A simple port of torchvision.models.swin_transformer.
__init__(self, patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], window_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, norm_layer: Callable = None, block: eqx.Module = None, downsample_layer: eqx.Module = None, *, key: Optional[jax.random.PRNGKey] = None)
#
Arguments:
- `patch_size`: Size of the patch to construct from the input image. Defaults to `(4, 4)`
- `embed_dim`: The dimension of the resulting embedding of the patch
- `depths`: Depth of each Swin Transformer layer
- `num_heads`: Number of attention heads in different layers
- `window_size`: Window size
- `mlp_ratio`: Ratio of mlp hidden dim to embedding dim. Defaults to `4.0`
- `dropout`: Dropout rate. Defaults to `0.0`
- `attention_dropout`: Attention dropout rate. Defaults to `0.0`
- `stochastic_depth_prob`: Stochastic depth rate. Defaults to `0.1`
- `num_classes`: Number of classes in the classification task. Also controls the final output shape `(num_classes,)`
- `norm_layer`: Normalisation applied to the intermediate outputs. Defaults to `LayerNorm2d`
- `block`: The SwinTransformer-v1/v2 block to use. Defaults to `_SwinTransformerBlock` which is used in `v1`
- `downsample_layer`: Downsample layer (patch merging). Defaults to `_PatchMerging` which is used in `v1`
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)
Note
Currently, input image dimensions are required to be divisible by the window_size for
each level of depth. For example, input image shape (3, 224, 224) works for window_size (7, 7)
and (3, 256, 256) works for (8, 8).
__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array
#
Arguments:
- `x`: The input `JAX` array
- `key`: Required parameter. Utilised by a few layers such as `Dropout` or `DropPath`
eqxvision.models.swin_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_tiny architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.swin_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_small architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.swin_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_base architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.swin_v2_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_v2_tiny architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.swin_v2_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_v2_small architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`
eqxvision.models.swin_v2_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer
#
Constructs a swin_v2_base architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.
Arguments:
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`