Swin Transformer#

eqxvision.models.SwinTransformer #

A simple port of torchvision.models.swin_transformer.

__init__(self, patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], window_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.1, num_classes: int = 1000, norm_layer: Callable = None, block: eqx.Module = None, downsample_layer: eqx.Module = None, *, key: Optional[jax.random.PRNGKey] = None) #

Arguments:

  • patch_size: Size of the patch to construct from the input image. Defaults to (4, 4)
  • embed_dim: The dimension of the resulting embedding of the patch
  • depths: Depth of each Swin Transformer layer
  • num_heads: Number of attention heads in different layers
  • window_size: Window size
  • mlp_ratio: Ratio of mlp hidden dim to embedding dim. Defaults to 4.0
  • dropout: Dropout rate. Defaults to 0.0
  • attention_dropout: Attention dropout rate. Defaults to 0.0
  • stochastic_depth_prob: Stochastic depth rate. Defaults to 0.1
  • num_classes: Number of classes in the classification task. Also controls the final output shape (num_classes,)
  • norm_layer: Normalisation applied to the intermediate outputs. Defaults to LayerNorm2d
  • block: The SwinTransformer-v1/v2 block to use. Defaults to _SwinTransformerBlock which is used in v1
  • downsample_layer: Downsample layer (patch merging). Defaults to _PatchMerging which is used in v1
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

Note

Currently, the spatial dimensions of the input image must be divisible by the window_size at every level of depth. For example, an input of shape (3, 224, 224) works with window_size (7, 7), and (3, 256, 256) works with (8, 8).
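
A minimal construction sketch: the hyper-parameters below mirror the swin_tiny configuration, and every other argument keeps its documented default.

```python
import jax
from eqxvision.models import SwinTransformer

key = jax.random.PRNGKey(0)
# swin_tiny-sized configuration; all other arguments keep their defaults.
model = SwinTransformer(
    patch_size=[4, 4],
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=[7, 7],
    num_classes=1000,
    key=key,
)
```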

__call__(self, x: Array, *, key: jax.random.PRNGKey) -> Array #

Arguments:

  • x: The input JAX array
  • key: Required parameter. Utilised by a few layers, such as Dropout or DropPath
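
A usage sketch for the forward pass. The model operates on a single CHW image with no batch axis; batching via jax.vmap follows the usual Equinox pattern, and forwarding key through swin_t's **kwargs is an assumption.

```python
import jax
import jax.numpy as jnp
from eqxvision.models import swin_t

# Randomly initialised swin_tiny; key is assumed to be forwarded via **kwargs.
model = swin_t(key=jax.random.PRNGKey(0))

x = jnp.zeros((3, 224, 224))                  # a single CHW image, no batch axis
logits = model(x, key=jax.random.PRNGKey(1))  # key drives Dropout/DropPath
print(logits.shape)                           # (num_classes,) == (1000,)

# Batched inference: map over the leading axis with one key per example.
images = jnp.zeros((8, 3, 224, 224))
keys = jax.random.split(jax.random.PRNGKey(2), 8)
batched_logits = jax.vmap(lambda img, k: model(img, key=k))(images, keys)
```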

eqxvision.models.swin_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_tiny architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
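
A brief sketch of both initialisation modes; the checkpoint path below is a placeholder, not a real URL.

```python
import jax
from eqxvision.models import swin_t

# Random initialisation; extra keyword arguments are forwarded to SwinTransformer.
net = swin_t(num_classes=1000, key=jax.random.PRNGKey(0))

# Porting pretrained parameters from a torchvision checkpoint (placeholder path).
weights = "path/or/url/to/torchvision_swin_t_checkpoint.pth"
pretrained = swin_t(torch_weights=weights)
```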

eqxvision.models.swin_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_small architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None

eqxvision.models.swin_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_base architecture from Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None

eqxvision.models.swin_v2_t(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_v2_tiny architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None
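
The V2 constructors follow the same pattern; the sketch below assumes extra keyword arguments such as num_classes are forwarded to SwinTransformer.__init__.

```python
import jax
from eqxvision.models import swin_v2_t

# swin_v2_tiny re-headed for a 10-class task; assumes **kwargs reach
# SwinTransformer.__init__ (no pretrained weights loaded here).
net = swin_v2_t(num_classes=10, key=jax.random.PRNGKey(0))
```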

eqxvision.models.swin_v2_s(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_v2_small architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None

eqxvision.models.swin_v2_b(torch_weights: str = None, **kwargs: Any) -> SwinTransformer #

Constructs a swin_v2_base architecture from Swin Transformer V2: Scaling Up Capacity and Resolution.

Arguments:

  • torch_weights: A Path or URL for the PyTorch weights. Defaults to None